diff --git "a/tokenizer/pipeline.py" "b/tokenizer/pipeline.py" new file mode 100644--- /dev/null +++ "b/tokenizer/pipeline.py" @@ -0,0 +1,7144 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations +import os, sys, importlib +from typing import Optional, Dict, List +import torch +from functools import partial +from diffusers import DiffusionPipeline +from diffusers.utils import logging +from accelerate import ( + init_empty_weights, + infer_auto_device_map, + load_checkpoint_and_dispatch, +) +from huggingface_hub import snapshot_download +from tqdm import tqdm +from copy import deepcopy +import random + +import cv2 +import numpy as np +from torchvision import transforms +from torchvision.transforms import functional as F +from torchvision.transforms import InterpolationMode + +from dataclasses import dataclass +from types import SimpleNamespace + +from einops import rearrange +from torch import Tensor, nn +from safetensors.torch import load_file as load_sft + +import copy +from typing import List, Tuple, Optional + +import torch.nn.functional as F +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel +from dataclasses import asdict, fields +from diffusers.models.modeling_utils import ModelMixin +from diffusers.configuration_utils import ConfigMixin +import math + +from transformers.activations import ACT2FN + +from torch import nn +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.nn.attention.flex_attention import flex_attention +from torch.nn.functional import scaled_dot_product_attention +from transformers.utils import ModelOutput + +from flash_attn import flash_attn_varlen_func + +torch._dynamo.config.cache_size_limit = 512 +torch._dynamo.config.accumulated_cache_size_limit = 4096 +# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune' +flex_attention = torch.compile(flex_attention) +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging +from typing import List, Optional, Tuple, Union +import torch.utils.checkpoint +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from typing import Optional, Tuple +from transformers.tokenization_utils import AddedToken +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +import json + +import unicodedata +from functools import lru_cache +import regex as re +from transformers.tokenization_utils import PreTrainedTokenizer + +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +import string +import warnings +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from 
transformers.convert_slow_tokenizer import import_protobuf +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import AddedToken + + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput +from transformers.utils import logging, requires_backends + + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +SPIECE_UNDERLINE = "▁" + +from typing import Dict, List, Optional, Union + +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + +logger = logging.get_logger(__name__) + +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import BaseModelOutputWithPooling, ImageClassifierOutput +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, + torch_int, +) +from typing import List, Optional, Union +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from transformers.utils import TensorType +from PIL import Image +from torch.nn.attention.flex_attention import or_masks, and_masks + + +def create_sparse_mask(document_lens, split_lens, attn_modes, device): + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + def full_and_noise_mask(b, h, q_idx, kv_idx): + return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) + + def remove_noise_mask(b, h, q_idx, kv_idx): + return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx]))) + + def sample_mask(b, h, q_idx, kv_idx): + return document_id[q_idx] == document_id[kv_idx] + + full_and_noise_tmp = [] + noise_tmp = [] + + for i, (length, model) in enumerate(zip(split_lens, attn_modes)): + value = i if model in ['full', 'noise'] else -1 + full_and_noise_tmp.extend([value] * length) + value_noise = i if model == 'noise' else -1 + noise_tmp.extend([value_noise] * length) + + full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device) + noise_seq_id = torch.Tensor(noise_tmp).to(device) + + document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) + + return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask) + + +def patchify(image, patch_size): + p = patch_size + c, h, w = image.shape + assert h % p == 0 and w % p == 0 + image = image.reshape(c, h // p, p, w // p, p) + 
image = torch.einsum("chpwq->hwpqc", image) + image = image.reshape(-1, p**2 * c) + return image + + +def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side): + num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size + coords_h = torch.arange(0, num_patches_h) + coords_w = torch.arange(0, num_patches_w) + pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() + return pos_ids + + +def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side): + num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size + boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side) + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w) + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten() + return pos_ids + + +def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"): + """ + nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within + a sample, where each sample contains multiple splits with different attn modes. + nested_attn_modes: whether to use full attn in each split. + """ + sample_len = sum(split_lens) + attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device) + + csum = 0 + for s, attn_mode in zip(split_lens, attn_modes): + assert attn_mode in ['causal', 'full', 'noise'] + if attn_mode == "causal": + attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril() + attention_mask[csum:csum + s, :csum] = 1 + else: + attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s)) + attention_mask[csum:csum + s, :csum] = 1 + csum += s + + csum = 0 + for s, attn_mode in zip(split_lens, attn_modes): + if attn_mode == "noise": + attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s)) + attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s)) + csum += s + + attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_( + ~attention_mask, float("-inf") + ) + + return attention_mask + + +def split_integer_exp_decay(S, ng_sample_decay=1.0): + if ng_sample_decay == 1.0: + N = random.randint(1, S) + else: + base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S)) + p = [base * math.pow(ng_sample_decay, i) for i in range(S)] + N = random.choices(list(range(1, S + 1)), p, k=1)[0] + cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S] + result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)] + return result, cumsum + + +def pil_img2rgb(image): + if image.mode == "RGBA" or image.info.get("transparency", None) is not None: + image = image.convert("RGBA") + white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255)) + white.paste(image, mask=image.split()[3]) + image = white + else: + image = image.convert("RGB") + + return image + + +def add_special_tokens(tokenizer): + all_special_tokens = [] + for k, v in tokenizer.special_tokens_map.items(): + if isinstance(v, str): + all_special_tokens.append(v) + elif isinstance(v, list): + all_special_tokens += v + + new_tokens = [] + + if '<|im_start|>' not in all_special_tokens: + new_tokens.append('<|im_start|>') + + if '<|im_end|>' not in 
all_special_tokens: + new_tokens.append('<|im_end|>') + + if '<|vision_start|>' not in all_special_tokens: + new_tokens.append('<|vision_start|>') + + if '<|vision_end|>' not in all_special_tokens: + new_tokens.append('<|vision_end|>') + + num_new_tokens = tokenizer.add_tokens(new_tokens) + bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>') + eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>') + start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>') + end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>') + + new_token_ids = dict( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + start_of_image=start_of_image, + end_of_image=end_of_image, + ) + + return tokenizer, new_token_ids, num_new_tokens + + +def len2weight(x, loss_reduction='square'): + if x == 0: + return x + if loss_reduction == 'token': + return 1 + if loss_reduction == 'sample': + return 1 / x + if loss_reduction == 'square': + return 1 / (x ** 0.5) + raise NotImplementedError(loss_reduction) + +class NaiveCache: + def __init__(self, num_layers): + self.key_cache = {k: None for k in range(num_layers)} + self.value_cache = {k: None for k in range(num_layers)} + + @property + def num_layers(self): + return len(self.key_cache) + + @property + def seq_lens(self): + if self.key_cache[0] is not None: + return self.key_cache[0].shape[0] + else: + return 0 + + +class _Qwen2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+ max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. 
+ max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Qwen2Model, _Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = _Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + is_causal=True, + _attn_implementation="flash_attention_2", + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + self.is_causal = is_causal + self._attn_implementation = _attn_implementation + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B" +_CONFIG_FOR_DOC = "_Qwen2Config" + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 +class Qwen2RotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[_Qwen2Config] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = 
self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + + # Copied from transformers.models.llama.modeling_llama.rotate_half + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + + # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 + class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + + # Copied from transformers.models.llama.modeling_llama.repeat_kv + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + + class Qwen2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: _Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class."
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = config.is_causal + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2FlashAttention2(Qwen2Attention): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +QWEN2_ATTENTION_CLASSES = { + "eager": Qwen2Attention, + "flash_attention_2": Qwen2FlashAttention2, +} + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`_Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = _Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. 
It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", + "tokenizer_file": "tokenizer.json", +} + + +MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768} + + +PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + +@lru_cache() +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Qwen2Tokenizer(PreTrainedTokenizer): + """ + Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding. + + Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import Qwen2Tokenizer + + >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer") + >>> tokenizer("Hello world")["input_ids"] + [9707, 1879] + + >>> tokenizer(" Hello world")["input_ids"] + [21927, 1879] + ``` + This is expected. + + You should not use GPT2Tokenizer instead, because of the different pretokenization rules. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*): + The beginning of sequence token. Not applicable for this tokenizer. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding, for example when batching sequences of different lengths. 
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not the model should clean up the spaces that were added when splitting the input text during the + tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces. + split_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the special tokens should be split during the tokenization process. The default behavior is + to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` = + `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<', + '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + clean_up_tokenization_spaces=False, + split_special_tokens=False, + **kwargs, + ): + # Qwen vocab does not contain control tokens; added tokens need to be special + bos_token = ( + AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(pad_token, str) + else pad_token + ) + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_merges = [] + with open(merges_file, encoding="utf-8") as merges_handle: + for i, line in enumerate(merges_handle): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + bpe_merges.append(tuple(line.split())) + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + # NOTE: the cache can grow without bound and will get really large for long running processes + # (esp. for texts of languages that do not use spaces between words, e.g. Chinese); technically + # not a memory leak but appears as one. + # GPT2Tokenizer has the same problem, so let's be consistent. + self.cache = {} + + self.pat = re.compile(PRETOKENIZE_REGEX) + + if kwargs.get("add_prefix_space", False): + logger.warning_once( + f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
+ ) + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.encoder) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def decode( + self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers + # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer + return super().decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not 
os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, **kwargs): + text = unicodedata.normalize("NFC", text) + return (text, kwargs) + + + +class SiglipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. 
+ eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + + Example: + + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class _SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import _SiglipVisionConfig, SiglipVisionModel + + >>> # Initializing a _SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = _SiglipVisionConfig() + + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipConfig(PretrainedConfig): + r""" + [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`_SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import SiglipConfig, SiglipModel + + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a _SiglipVisionConfig + >>> from transformers import SiglipTextConfig, _SiglipVisionConfig + + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = _SiglipVisionConfig() + + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip" + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `_SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = _SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: _SiglipVisionConfig, **kwargs): + r""" + Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + model configuration. + + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +if is_vision_available(): + import PIL + + +class SiglipImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. 
Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: bool = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. 
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
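+ # (for example, an (H, W, 3) uint8 array is typically inferred as channels_last, while a (3, H, W)
+ # array is inferred as channels_first)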
+ input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + height, width = size["height"], size["width"] + images = [ + resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + + +# General docstring +_CONFIG_FOR_DOC = "SiglipConfig" +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
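+
+ In other words, with e.g. `mean=3.0`, `std=2.0`, `a=-2.0`, `b=2.0`, this variant produces values in
+ `[mean + a * std, mean + b * std] = [-1.0, 7.0]`, whereas `torch.nn.init.trunc_normal_` would clamp
+ the final values to `[-2.0, 2.0]`.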
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. 
+ + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. 
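+
+ Example (a minimal sketch, assuming `outputs` is a `SiglipOutput` returned by `SiglipModel.forward`):
+
+ ```python
+ >>> # SigLIP is trained with a sigmoid loss, so pairwise image-text probabilities are obtained
+ >>> # with an element-wise sigmoid rather than a softmax over the batch:
+ >>> probs = torch.sigmoid(outputs.logits_per_image)
+ ```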
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + +class SiglipSdpaAttention(SiglipAttention): + """ + Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + is_causal = False + + # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if self.is_causal and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "SiglipTextEmbeddings", + "SiglipEncoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, SiglipForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. 
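+ # (note: this differs from CLIP, which pools the hidden state at the EOS token position looked up in
+ # `input_ids`; here the last position is taken directly and then projected through `self.head`)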
+ pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: _SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise TypeError( + 
"config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, _SiglipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type _SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention implementation + text_model = SiglipTextModel._from_config(text_config) + vision_model = SiglipVisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. 
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = ( + torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp() + + self.logit_bias + ) + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@add_start_docstrings( + """ + SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. 
+ """, + SIGLIP_START_DOCSTRING, +) +class SiglipForImageClassification(SiglipPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: SiglipConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = SiglipVisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SiglipForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `SiglipModel` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vision_model( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + sequence_output = outputs[0] + + # average pool the patch tokens + sequence_output = torch.mean(sequence_output, dim=1) + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class SiglipProcessor(ProcessorMixin): + r""" + Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor. + + [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the + [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`]): + The image processor is a required input. + tokenizer ([`SiglipTokenizer`]): + The tokenizer is a required input. 
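+
+ Example (a minimal sketch; `image_processor` and `tokenizer` are assumed to be a `SiglipImageProcessor`
+ and a `SiglipTokenizer` instantiated elsewhere, e.g. from a downloaded SigLIP checkpoint):
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
+ >>> # important: pad to max_length, as that is how SigLIP was trained
+ >>> inputs = processor(text=["a photo of 2 cats"], images=image, padding="max_length", return_tensors="pt")
+ ```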
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = "SiglipTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: int = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` argument to + SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
+ """ + + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + + if images is not None: + image_features = self.image_processor(images, return_tensors=return_tensors) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features.pixel_values + return encoding + elif text is not None: + return encoding + else: + return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +class SiglipTokenizer(PreTrainedTokenizer): + """ + Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + model_max_length (`int`, *optional*, defaults to 64): + The maximum length (in number of tokens) for model inputs. 
+ do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + pad_token="", + additional_special_tokens=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + model_max_length=64, + do_lower_case=True, + **kwargs, + ) -> None: + requires_backends(self, "protobuf") + + pad_token = ( + AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(pad_token, str) + else pad_token + ) + unk_token = ( + AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(unk_token, str) + else unk_token + ) + eos_token = ( + AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True) + if isinstance(eos_token, str) + else eos_token + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.do_lower_case = do_lower_case + self.vocab_file = vocab_file + + self.sp_model = self.get_spm_processor() + self.vocab_file = vocab_file + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + model_max_length=model_max_length, + do_lower_case=do_lower_case, + **kwargs, + ) + + def get_spm_processor(self): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf() + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size + def vocab_size(self): + return self.sp_model.get_piece_size() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
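+ 
+ Example (illustrative): for `token_ids_0 = [10, 11]` and no `token_ids_1`, the method returns `[0, 0, 1]`, where the trailing `1` marks the EOS token appended by `build_inputs_with_special_tokens`.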
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def remove_punctuation(self, text: str) -> str: + return text.translate(str.maketrans("", "", string.punctuation)) + + # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + def canonicalize_text(self, text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (puncuation removed). 
+ + Args: + text (`str`): + String to be canonicalized. + keep_punctuation_exact_string (`str`, *optional*): + If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}' + (but will still remove '{' and '}' that appear separately). + """ + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string) + ) + else: + text = self.remove_punctuation(text) + text = re.sub(r"\s+", " ", text) + text = text.strip() + + return text + + def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. + """ + tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + @property + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. + + For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`. + + Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + text = self.canonicalize_text(text, keep_punctuation_exact_string=None) + tokens = self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +class BagelConfig(PretrainedConfig): + def __init__( + self, + visual_gen=True, + visual_und=True, + llm_config=None, + vit_config=None, + vae_config=None, + latent_patch_size=2, + max_latent_size=32, + vit_max_num_patch_per_side=70, + connector_act="gelu_pytorch_tanh", + interpolate_pos=False, + timestep_shift=1.0, + **kwargs + ): + super().__init__(**kwargs) + self.visual_gen = visual_gen + self.visual_und = visual_und + self.llm_config = llm_config + self.vit_config = vit_config + self.vae_config = vae_config + self.latent_patch_size = latent_patch_size + self.max_latent_size = max_latent_size + self.vit_max_num_patch_per_side = vit_max_num_patch_per_side + self.connector_act = connector_act + self.interpolate_pos = interpolate_pos + self.timestep_shift = timestep_shift + + +class Bagel(PreTrainedModel): + config_class = BagelConfig + base_model_prefix = 'bagel' + + def __init__( + self, + config: BagelConfig, # ← first! 
+ language_model: Optional[Qwen2ForCausalLM] = None, + vit_model: Optional[SiglipVisionModel] = None, + ): + if isinstance(config.llm_config, dict): + config.llm_config = Qwen2Config(**config.llm_config) + if isinstance(config.vit_config, dict): + config.vit_config = SiglipVisionConfig(**config.vit_config) + if isinstance(config.vae_config, dict): # ← NEW + config.vae_config = SimpleNamespace(**config.vae_config) + + if language_model is None or vit_model is None: + with init_empty_weights(): # ‘meta’ device → 0 RAM + language_model = Qwen2ForCausalLM(config.llm_config) + vit_model = SiglipVisionModel(config.vit_config) + + super().__init__(config) + self.language_model = language_model + self.hidden_size = config.llm_config.hidden_size + self.use_moe = "Mo" in config.llm_config.layer_module + self.num_heads = config.llm_config.num_attention_heads + + if config.visual_gen: + self.latent_patch_size = config.latent_patch_size + self.timestep_shift = config.timestep_shift + self.latent_downsample = config.vae_config.downsample * config.latent_patch_size + self.max_latent_size = config.max_latent_size + self.latent_channel = config.vae_config.z_channels + self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel + self.time_embedder = TimestepEmbedder(self.hidden_size) + self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size) + self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim) + self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size) + + if config.visual_und: + self.vit_model = vit_model + self.vit_patch_size = config.vit_config.patch_size + self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side + self.vit_hidden_size = config.vit_config.hidden_size + self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act) + self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size) + self.vit_model.vision_model.embeddings.convert_conv2d_to_linear(config.vit_config, meta=True) + if config.interpolate_pos: + self.get_flattened_position_ids = get_flattened_position_ids_interpolate + else: + self.get_flattened_position_ids = get_flattened_position_ids_extrapolate + + self.config = config + self._init_weights() + + def _init_weights(self): + if self.config.visual_gen: + nn.init.constant_(self.llm2vae.weight, 0) + nn.init.constant_(self.llm2vae.bias, 0) + + def forward( + self, + sequence_length: int, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + sample_lens: List[int], + packed_position_ids: torch.LongTensor, + nested_attention_masks: List[torch.Tensor] = None, + split_lens: List[int] = None, + attn_modes: List[str] = None, + # for visual understanding + ce_loss_indexes: Optional[torch.BoolTensor] = None, + packed_label_ids: Optional[torch.LongTensor] = None, + packed_vit_tokens: Optional[torch.Tensor] = None, + packed_vit_token_indexes: Optional[torch.LongTensor] = None, + packed_vit_position_ids: Optional[torch.LongTensor] = None, + vit_token_seqlens: Optional[torch.IntTensor] = None, + # for visual generation + padded_latent: Optional[torch.Tensor] = None, + patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None, + packed_latent_position_ids: Optional[torch.LongTensor] = None, + packed_vae_token_indexes: Optional[torch.LongTensor] = None, + packed_timesteps: Optional[torch.LongTensor] = None, + mse_loss_indexes: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: + """ + Args: + sequence_length: length of sequence. 
+ packed_text_ids: 1-D int tensor, packed text token ids. + packed_text_indexes: 1-D int tensor, packed text token indexes in sequence. + sample_lens: A list of N ints, length of each sample in packed_sequence. + nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and + -inf means ignore. + packed_position_ids: packed 1-D positions, an image has only one global position shared + by all latent tokens. + + packed_vit_tokens: packed patchified image tokens for vit model. + packed_vit_position_ids: 1-D int tensor, the position of each token for vit model. + packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence. + vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model. + packed_label_ids: 1-D int tensor, packed label token ids. + ce_loss_indexes: 1-D bool tensor, where to compute ce loss. + + padded_latent: padded latent from VAE encoder. + patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image. + packed_latent_position_ids: 1-D int tensor, the position of each token for latent. + packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence. + packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image. + mse_loss_indexes: 1-D bool tensor, where to compute mse loss. + """ + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + if nested_attention_masks is None: + sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device) + seqlen = sum(sample_lens) + block_mask = create_block_mask( + sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen, + device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True + ) + attention_mask = block_mask + else: + attention_mask = nested_attention_masks + + if self.config.visual_und: + cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)) + cu_seqlens = cu_seqlens.to(torch.int32) + max_seqlen = torch.max(vit_token_seqlens).item() + packed_vit_token_embed = self.vit_model( + packed_pixel_values=packed_vit_tokens, + packed_flattened_position_ids=packed_vit_position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + packed_vit_token_embed = self.connector(packed_vit_token_embed) + vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids) + packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb + packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed + + if self.config.visual_gen: + p = self.latent_patch_size + packed_latent = [] + for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes): + latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p) + latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel) + packed_latent.append(latent) + packed_latent_clean = torch.cat(packed_latent, dim=0) + + noise = torch.randn_like(packed_latent_clean) + packed_timesteps = torch.sigmoid(packed_timesteps) + packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps) + packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise + packed_timestep_embeds = self.time_embedder(packed_timesteps) + latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids) 
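+ # Project the noised latent patches to the LLM hidden size, then add the timestep
+ # and 2-D position embeddings before scattering them into the packed sequence at
+ # the VAE-token slots.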
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb + packed_sequence[packed_vae_token_indexes] = packed_latent + + extra_inputs = {} + if self.use_moe: + packed_und_token_indexes = packed_text_indexes + if packed_vit_token_indexes is not None: + packed_und_token_indexes=torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0) + extra_inputs.update( + packed_und_token_indexes=packed_und_token_indexes, + packed_gen_token_indexes=packed_vae_token_indexes, + ) + + last_hidden_state = self.language_model( + packed_sequence=packed_sequence, + sample_lens=sample_lens, + attention_mask=attention_mask, + packed_position_ids=packed_position_ids, + **extra_inputs, + ) + + mse = None + if self.config.visual_gen: + packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes]) + target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise + has_mse = packed_timesteps > 0 + mse = (packed_mse_preds - target[has_mse]) ** 2 + + ce = None + if ce_loss_indexes is not None: + packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes]) + ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none") + + return dict(mse=mse, ce=ce) + + + def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids): + packed_text_ids = list() + packed_text_position_ids = list() + text_token_lens = list() + packed_text_indexes = list() + packed_key_value_indexes = list() + + curr = 0 + newlens, new_rope = list(), list() + for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + text_ids = tokenizer.encode(prompt) + text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']] + text_token_lens.append(len(text_ids)) + packed_text_ids.extend(text_ids) + packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids))) + packed_text_indexes.extend(range(curr, curr + len(text_ids))) + newlens.append(curr_kvlen + len(text_ids)) + new_rope.append(curr_position_id + len(text_ids)) + curr += len(text_ids) + + generation_input = { + "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int), + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + } + + return generation_input, newlens, new_rope + + @torch.no_grad + def forward_cache_update_text( + self, + past_key_values: NaiveCache, + packed_text_ids: torch.IntTensor, + packed_text_position_ids: torch.LongTensor, + text_token_lens: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + + extra_inputs = {} + if self.use_moe: + extra_inputs = {"mode": "und"} + + output = self.language_model.forward_inference( + packed_query_sequence=packed_text_embedding, + query_lens=text_token_lens, + packed_query_position_ids=packed_text_position_ids, + packed_query_indexes=packed_text_indexes, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + 
key_values_lens=key_values_lens, + update_past_key_values=True, + is_causal=True, + **extra_inputs, + ) + past_key_values = output.past_key_values + + return past_key_values + + def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids): + packed_vit_token_indexes = list() + vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list() + packed_text_ids, packed_text_indexes = list(), list() + packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() + packed_key_value_indexes = list() + + _curr = curr = 0 + newlens, new_rope = list(), list() + for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_text_ids.append(new_token_ids['start_of_image']) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + image_tensor = transforms(image) + vit_position_ids = self.get_flattened_position_ids( + image_tensor.size(1), image_tensor.size(2), + self.vit_patch_size, + max_num_patches_per_side=self.vit_max_num_patch_per_side + ) + vit_tokens = patchify(image_tensor, self.vit_patch_size) + packed_vit_tokens.append(vit_tokens) + num_img_tokens = vit_tokens.shape[0] + packed_vit_position_ids.append(vit_position_ids) + vit_token_seqlens.append(num_img_tokens) + packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens)) + packed_indexes.extend(range(curr, curr + num_img_tokens)) + curr += num_img_tokens + _curr += num_img_tokens + + packed_text_ids.append(new_token_ids['end_of_image']) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) + packed_seqlens.append(num_img_tokens + 2) + newlens.append(curr_kvlen + num_img_tokens + 2) + new_rope.append(curr_position_id + 1) + + generation_input = { + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int), + "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0), + "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0), + "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long), + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + } + + return generation_input, newlens, new_rope + + @torch.no_grad + def forward_cache_update_vit( + self, + past_key_values: NaiveCache, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_vit_tokens: torch.Tensor, + packed_vit_token_indexes: torch.LongTensor, + packed_vit_position_ids: torch.LongTensor, + vit_token_seqlens: torch.IntTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_indexes: torch.LongTensor, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + 
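+ # Scatter the start_of_image / end_of_image token embeddings into the packed query
+ # sequence; the ViT patch embeddings computed below fill the image-token slots.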
packed_sequence[packed_text_indexes] = packed_text_embedding + + cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)) + cu_seqlens = cu_seqlens.to(torch.int32) + max_seqlen = torch.max(vit_token_seqlens).item() + packed_vit_token_embed = self.vit_model( + packed_pixel_values=packed_vit_tokens, + packed_flattened_position_ids=packed_vit_position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + packed_vit_token_embed = self.connector(packed_vit_token_embed) + pos_emb = self.vit_pos_embed(packed_vit_position_ids) + packed_vit_token_embed = packed_vit_token_embed + pos_emb + if packed_vit_token_embed.dtype != packed_sequence.dtype: + packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype) + packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed + + extra_inputs = {} + if self.use_moe: + extra_inputs = {"mode": "und"} + + output = self.language_model.forward_inference( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + key_values_lens=key_values_lens, + update_past_key_values=True, + is_causal=False, + **extra_inputs, + ) + past_key_values = output.past_key_values + + return past_key_values + + def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0): + patchified_vae_latent_shapes, packed_vae_position_ids = list(), list() + packed_vae_token_indexes = list() + packed_text_ids, packed_text_indexes = list(), list() + packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() + packed_key_value_indexes = list() + + _curr = curr = 0 + vae_image_tensors = list() + newlens, new_rope = list(), list() + for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_text_ids.append(new_token_ids['start_of_image']) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + image_tensor = transforms(image) + vae_image_tensors.append(image_tensor) + vae_posiiton_ids = self.get_flattened_position_ids( + image_tensor.size(1), image_tensor.size(2), + self.latent_downsample, + max_num_patches_per_side=self.max_latent_size + ) + packed_vae_position_ids.append(vae_posiiton_ids) + H, W = image_tensor.shape[1:] + h = H // self.latent_downsample + w = W // self.latent_downsample + patchified_vae_latent_shapes.append((h, w)) + + num_img_tokens = w * h + packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens)) + packed_indexes.extend(range(curr, curr + num_img_tokens)) + curr += num_img_tokens + _curr += num_img_tokens + + packed_text_ids.append(new_token_ids['end_of_image']) + packed_text_indexes.append(_curr) + packed_indexes.append(curr) + curr += 1 + _curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) + packed_seqlens.append(num_img_tokens + 2) + newlens.append(curr_kvlen + num_img_tokens + 2) + new_rope.append(curr_position_id + 1) + + image_sizes = [item.shape for item in vae_image_tensors] + max_image_size = [max(item) for item in list(zip(*image_sizes))] + padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size)) + for i, image_tensor in enumerate(vae_image_tensors): + padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor + + generation_input = 
{ + "padded_images": padded_images, + "patchified_vae_latent_shapes": patchified_vae_latent_shapes, + "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), + "packed_timesteps": torch.tensor([timestep]), + "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + } + + return generation_input, newlens, new_rope + + @torch.no_grad + def forward_cache_update_vae( + self, + vae_model, + past_key_values: NaiveCache, + padded_images: torch.Tensor, + patchified_vae_latent_shapes: List, + packed_vae_position_ids: torch.LongTensor, + packed_timesteps: torch.Tensor, + packed_vae_token_indexes: torch.LongTensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + packed_key_value_indexes: torch.Tensor, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + padded_latent = vae_model.encode(padded_images) + + p = self.latent_patch_size + packed_latent = list() + for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes): + latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p) + latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel) + packed_latent.append(latent) + packed_latent = torch.cat(packed_latent, dim=0) + packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) + packed_timestep_embeds = self.time_embedder(packed_timesteps) + packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed + if packed_latent.dtype != packed_sequence.dtype: + packed_latent = packed_latent.to(packed_sequence.dtype) + packed_sequence[packed_vae_token_indexes] = packed_latent + + extra_inputs = {} + if self.use_moe: + extra_inputs = { + "mode": "gen", + "packed_vae_token_indexes": packed_vae_token_indexes, + "packed_text_indexes": packed_text_indexes + } + + output = self.language_model.forward_inference( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=True, + is_causal=False, + **extra_inputs, + ) + past_key_values = output.past_key_values + + return past_key_values + + def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids): + packed_text_ids, packed_text_indexes = list(), list() + packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list() + packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list() + packed_key_value_indexes = list() + + query_curr = curr = 0 + for (H, W), 
curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_text_ids.append(new_token_ids['start_of_image']) + packed_text_indexes.append(query_curr) + packed_indexes.append(curr) + curr += 1 + query_curr += 1 + + vae_posiiton_ids = self.get_flattened_position_ids( + H, W, + self.latent_downsample, + max_num_patches_per_side=self.max_latent_size + ) + packed_vae_position_ids.append(vae_posiiton_ids) + + h, w = H // self.latent_downsample, W // self.latent_downsample + num_image_tokens = h * w + packed_init_noises.append( + torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2) + ) + packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens)) + packed_indexes.extend(range(curr, curr + num_image_tokens)) + curr += num_image_tokens + query_curr += num_image_tokens + + packed_text_ids.append(new_token_ids['end_of_image']) + packed_text_indexes.append(query_curr) + packed_indexes.append(curr) + curr += 1 + query_curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) + packed_seqlens.append(num_image_tokens + 2) + + generation_input = { + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_init_noises": torch.cat(packed_init_noises, dim=0), + "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), + "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + } + + return generation_input + + def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes): + packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list() + + query_curr = curr = 0 + for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_indexes.append(curr) + curr += 1 + query_curr += 1 + + h, w = H // self.latent_downsample, W // self.latent_downsample + num_image_tokens = h * w + packed_indexes.extend(range(curr, curr + num_image_tokens)) + curr += num_image_tokens + query_curr += num_image_tokens + + packed_indexes.append(curr) + curr += 1 + query_curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) + + generation_input = { + "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + } + + return generation_input + + @torch.no_grad + def generate_image( + self, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_init_noises: torch.Tensor, + packed_vae_position_ids: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_position_ids: torch.LongTensor, + packed_indexes: torch.LongTensor, + past_key_values: 
NaiveCache, + key_values_lens: torch.IntTensor, + packed_key_value_indexes: torch.LongTensor, + num_timesteps: int = 24, + timestep_shift: float = 1.0, + cfg_renorm_min: float = 0.0, + cfg_renorm_type: str = "global", + cfg_interval: Optional[Tuple[float, float]] = [0, 1], + # cfg_text + cfg_text_scale: float = 1.0, + cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None, + cfg_text_packed_position_ids: Optional[torch.LongTensor] = None, + cfg_text_past_key_values: Optional[NaiveCache] = None, + cfg_text_key_values_lens: Optional[torch.IntTensor] = None, + cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None, + # cfg_img + cfg_img_scale: float = 1.0, + cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None, + cfg_img_packed_position_ids: Optional[torch.LongTensor] = None, + cfg_img_past_key_values: Optional[NaiveCache] = None, + cfg_img_key_values_lens: Optional[torch.IntTensor] = None, + cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None, + cfg_type: str = "parallel", + ): + x_t = packed_init_noises + + timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device) + timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps) + dts = timesteps[:-1] - timesteps[1:] + timesteps = timesteps[:-1] + + for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): + + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + if t > cfg_interval[0] and t <= cfg_interval[1]: + cfg_text_scale_ = cfg_text_scale + cfg_img_scale_ = cfg_img_scale + else: + cfg_text_scale_ = 1.0 + cfg_img_scale_ = 1.0 + v_t = self._forward_flow( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + packed_position_ids=packed_position_ids, + packed_indexes=packed_indexes, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + cfg_renorm_min=cfg_renorm_min, + cfg_renorm_type=cfg_renorm_type, + # cfg_text + cfg_text_scale=cfg_text_scale_, + cfg_text_packed_position_ids=cfg_text_packed_position_ids, + cfg_text_packed_query_indexes=cfg_text_packed_query_indexes, + cfg_text_key_values_lens=cfg_text_key_values_lens, + cfg_text_past_key_values=cfg_text_past_key_values, + cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes, + # cfg_img + cfg_img_scale=cfg_img_scale_, + cfg_img_packed_position_ids=cfg_img_packed_position_ids, + cfg_img_packed_query_indexes=cfg_img_packed_query_indexes, + cfg_img_key_values_lens=cfg_img_key_values_lens, + cfg_img_past_key_values=cfg_img_past_key_values, + cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes, + cfg_type=cfg_type, + ) + + x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + + @torch.no_grad + def _forward_flow( + self, + x_t: torch.Tensor, + timestep: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_vae_position_ids: torch.LongTensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_indexes: torch.LongTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + key_values_lens: torch.IntTensor, + past_key_values: NaiveCache, + packed_key_value_indexes: torch.LongTensor, + cfg_renorm_min: float = 0.0, + 
cfg_renorm_type: str = "global", + # cfg_text + cfg_text_scale: float = 1.0, + cfg_text_packed_position_ids: Optional[torch.LongTensor] = None, + cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None, + cfg_text_key_values_lens: Optional[torch.Tensor] = None, + cfg_text_past_key_values: Optional[NaiveCache] = None, + cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None, + # cfg_img + cfg_img_scale: float = 1.0, + cfg_img_packed_position_ids: Optional[torch.LongTensor] = None, + cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None, + cfg_img_key_values_lens: Optional[torch.Tensor] = None, + cfg_img_past_key_values: Optional[NaiveCache] = None, + cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None, + cfg_type: str = "parallel", + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + assert timestep.unique().shape[0] == 1 + packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) + packed_timestep_embeds = self.time_embedder(timestep) + x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed + if x_t.dtype != packed_sequence.dtype: + x_t = x_t.to(packed_sequence.dtype) + packed_sequence[packed_vae_token_indexes] = x_t + + extra_inputs = {} + if self.use_moe: + extra_inputs = { + "mode": "gen", + "packed_vae_token_indexes": packed_vae_token_indexes, + "packed_text_indexes": packed_text_indexes + } + + output = self.language_model.forward_inference( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + v_t = self.llm2vae(output.packed_query_sequence) + v_t = v_t[packed_vae_token_indexes] + + if cfg_text_scale > 1.0: + cfg_text_output = self.language_model.forward_inference( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=cfg_text_packed_position_ids, + packed_query_indexes=cfg_text_packed_query_indexes, + past_key_values=cfg_text_past_key_values, + key_values_lens=cfg_text_key_values_lens, + packed_key_value_indexes=cfg_text_packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence) + cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes] + + if cfg_img_scale > 1.0: + cfg_img_output = self.language_model.forward_inference( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=cfg_img_packed_position_ids, + packed_query_indexes=cfg_img_packed_query_indexes, + past_key_values=cfg_img_past_key_values, + key_values_lens=cfg_img_key_values_lens, + packed_key_value_indexes=cfg_img_packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence) + cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes] + + if cfg_text_scale > 1.0: + if cfg_renorm_type == "text_channel": + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True) + 
scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + v_t_text = v_t_text_ * scale + if cfg_img_scale > 1.0: + v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t) + else: + v_t = v_t_text + else: + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + + if cfg_img_scale > 1.0: + v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t) + else: + v_t_ = v_t_text_ + + # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit + if cfg_renorm_type == "global": + norm_v_t = torch.norm(v_t) + norm_v_t_ = torch.norm(v_t_) + elif cfg_renorm_type == "channel": + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True) + else: + raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted") + scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + v_t = v_t_ * scale + else: + # No CFG + pass + + return v_t + + def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): + packed_start_tokens, packed_key_value_indexes = list(), list() + packed_query_position_ids = list() + + curr = 0 + for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + packed_start_tokens.append(new_token_ids['bos_token_id']) + packed_query_position_ids.append(curr_position_id) + curr += curr_kvlen + + generation_input = { + "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long), + "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + } + + return generation_input + + @torch.no_grad + def generate_text( + self, + past_key_values: NaiveCache, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + packed_start_tokens: torch.LongTensor, + packed_query_position_ids: torch.LongTensor, + max_length: int, + do_sample: bool = False, + temperature: float = 1.0, + end_token_id: int = None, + ): + step = 0 + generated_sequence = [] + curr_tokens = packed_start_tokens + while step < max_length: + generated_sequence.append(curr_tokens) + packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens) + query_lens = torch.ones_like(curr_tokens) + packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange( + 0, len(key_values_lens), + device=key_values_lens.device, + dtype=key_values_lens.dtype + ) + + uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) + for i in range(len(uppacked)): + uppacked[i] += i + packed_key_value_indexes = torch.cat(uppacked, dim=0) + + extra_inputs = {} + if self.use_moe: + extra_inputs = {"mode": "und"} + + output = self.language_model.forward_inference( + packed_query_sequence=packed_text_embedding, + query_lens=query_lens, + packed_query_position_ids=packed_query_position_ids, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=True, + is_causal=True, + **extra_inputs, + ) + past_key_values = output.past_key_values + packed_query_sequence = output.packed_query_sequence + pred_logits = self.language_model.lm_head(packed_query_sequence) + + if do_sample: + probs = nn.functional.softmax(pred_logits / temperature, dim=-1) + curr_tokens = 
torch.multinomial(probs, num_samples=1).squeeze(1) + else: + curr_tokens = torch.argmax(pred_logits, dim=-1) + + uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) + for i in range(len(uppacked)): + uppacked[i] = torch.cat( + [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0 + ) + packed_key_value_indexes = torch.cat(uppacked, dim=0) + key_values_lens = key_values_lens + 1 + packed_query_position_ids = packed_query_position_ids + 1 + step += 1 + + if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1 + break + + output_device = generated_sequence[0].device + return torch.stack([i.to(output_device) for i in generated_sequence], dim=0) + + # for evaluation + @torch.no_grad() + def chat( + self, + tokenizer, + new_token_ids, + image_transform, + images, + prompt, + max_length: int, + do_sample: bool = False, + temperature: float = 1.0, + ): + device = next(self.parameters()).device + + if isinstance(new_token_ids, dict): + for k, v in new_token_ids.items(): + if torch.is_tensor(v): + new_token_ids[k] = v.to(device) + elif torch.is_tensor(new_token_ids): + new_token_ids = new_token_ids.to(device) + + # prefill + past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers) + newlens = [0] + new_rope = [0] + + # add images + for image in images: + generation_input, newlens, new_rope = self.prepare_vit_images( + curr_kvlens=newlens, + curr_rope=new_rope, + images=[image], + transforms=image_transform, + new_token_ids=new_token_ids, + ) + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(device) + with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input) + + # add text + generation_input, newlens, new_rope = self.prepare_prompts( + curr_kvlens=newlens, + curr_rope=new_rope, + prompts=[prompt], + tokenizer=tokenizer, + new_token_ids=new_token_ids, + ) + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(device) + with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + past_key_values = self.forward_cache_update_text(past_key_values, **generation_input) + + # decode + generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids) + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(device) + with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): + unpacked_latent = self.generate_text( + past_key_values=past_key_values, + max_length=max_length, + do_sample=do_sample, + temperature=temperature, + end_token_id=new_token_ids['eos_token_id'], + **generation_input, + ) + output = tokenizer.decode(unpacked_latent[:,0]) + output = output.split('<|im_end|>')[0].split('<|im_start|>')[1] + + return output + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of 
dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class MLPconnector(nn.Module): + def __init__(self, in_dim: int, out_dim: int, hidden_act: str): + super().__init__() + self.activation_fn = ACT2FN[hidden_act] + self.fc1 = nn.Linear(in_dim, out_dim) + self.fc2 = nn.Linear(out_dim, out_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class PositionEmbedding(nn.Module): + def __init__(self, max_num_patch_per_side, hidden_size): + super().__init__() + self.max_num_patch_per_side = max_num_patch_per_side + self.hidden_size = hidden_size + self.pos_embed = nn.Parameter( + torch.zeros(max_num_patch_per_side ** 2, hidden_size), + requires_grad=False + ) + self._init_weights() + + def _init_weights(self): + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + + def forward(self, position_ids): + return self.pos_embed[position_ids] + + +class Qwen2Config(_Qwen2Config): + r""" + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. 
Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. 
The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
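+        qk_norm (`bool`, *optional*, defaults to `True`):
+            Whether to apply RMSNorm to the query and key states before the rotary embedding is
+            applied (see `PackedAttention` / `PackedAttentionMoT`).
+        layer_module (`str`, *optional*, defaults to `"Qwen2DecoderLayer"`):
+            Name of the decoder layer class to instantiate; one of `"Qwen2DecoderLayer"`,
+            `"Qwen2MoEDecoderLayer"` or `"Qwen2MoTDecoderLayer"`.
+        freeze_und (`bool`, *optional*, defaults to `False`):
+            If `True`, understanding (und) tokens are detached during training so that the und
+            branch receives no gradient and only the generation (gen) branch is updated.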
+ + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + is_causal=True, + _attn_implementation="flash_attention_2", + qk_norm=True, + layer_module="Qwen2DecoderLayer", + freeze_und=False, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + use_sliding_window=use_sliding_window, + sliding_window=sliding_window, + max_window_layers=max_window_layers, + attention_dropout=attention_dropout, + is_causal=is_causal, + _attn_implementation=_attn_implementation, + **kwargs, + ) + self.qk_norm = qk_norm + self.layer_module = layer_module + self.freeze_und = freeze_und + + +@dataclass +class BaseNavitOutputWithPast(ModelOutput): + packed_query_sequence: torch.FloatTensor = None + past_key_values: Optional[NaiveCache] = None + + +def pad_sequence(tensor, pad_size): + H, L, D = tensor.shape + pad_tensor = tensor.new_zeros((H, pad_size, D)) + return torch.cat([tensor, pad_tensor], dim=1) + + +class PackedAttention(Qwen2Attention): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + if self.config.qk_norm: + self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def forward(self, *args, **kwargs): + if self.training: + return self.forward_train(*args, **kwargs) + else: + return self.forward_inference(*args, **kwargs) + + def forward_train( + self, + packed_sequence: torch.Tensor, + sample_lens: List[int], + attention_mask: List[torch.Tensor], + packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ): + packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim) + packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim) + + packed_query_states = self.q_norm(packed_query_states) + packed_key_states = self.k_norm(packed_key_states) + + packed_cos, packed_sin = packed_position_embeddings + packed_query_states, packed_key_states = apply_rotary_pos_emb( + packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1 + ) + + if 
isinstance(attention_mask, List): + packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) + packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim) + packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) + packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim) + + unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1) + unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1) + unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1) + upacked_attn_output = [] + for query_states, key_states, value_states, attention_mask_per_sample in zip( + unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask + ): + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): + attn_output = scaled_dot_product_attention( + query_states.to(torch.bfloat16).unsqueeze(0), + key_states.to(torch.bfloat16).unsqueeze(0), + value_states.to(torch.bfloat16).unsqueeze(0), + attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0), + ) + upacked_attn_output.append(attn_output.squeeze(0)) + packed_attn_output = torch.cat(upacked_attn_output, dim=1) + else: + pad_size = sum(sample_lens) - packed_query_states.shape[0] + packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size) + packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size) + packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size) + packed_attn_output = flex_attention( + packed_query_states.unsqueeze(0), + packed_key_states.unsqueeze(0), + packed_value_states.unsqueeze(0), + enable_gqa=True, + block_mask=attention_mask, + ) + end_index = packed_attn_output.shape[2] - pad_size + packed_attn_output = packed_attn_output[0, :, :end_index, :] + + packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size) + packed_attn_output = self.o_proj(packed_attn_output) + + return packed_attn_output + + def forward_inference( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: Optional[NaiveCache] = None, + key_values_lens: Optional[torch.Tensor] = None, + packed_key_value_indexes: Optional[torch.Tensor] = None, + update_past_key_values=True, + is_causal=True, + ): + packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) + packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + + packed_query_states = self.q_norm(packed_query_states) + packed_key_states = self.k_norm(packed_key_states) + + packed_cos, packed_sin = packed_query_position_embeddings + packed_query_states, packed_key_states = apply_rotary_pos_emb( + packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1 + ) + + packed_query_states = packed_query_states.to(torch.bfloat16) + packed_key_states = packed_key_states.to(torch.bfloat16) + packed_value_states = packed_value_states.to(torch.bfloat16) + + if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + past_key_states = past_key_values.key_cache[self.layer_idx] + past_value_states = past_key_values.value_cache[self.layer_idx] + + 
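+            # Merge cached and freshly projected K/V into one packed buffer by index scatter:
+            # new entries land at packed_query_indexes and cached entries at packed_key_value_indexes,
+            # so flash_attn_varlen_func sees each sample's keys/values in positional order.
+            # E.g. for one sample with key_values_lens=[5] and query_lens=[1],
+            # packed_key_value_indexes=[0..4] holds the cache and packed_query_indexes=[5] the new token.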
seqlens = sum(query_lens) + sum(key_values_lens) + merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim)) + merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim)) + merged_key_states[packed_query_indexes] = packed_key_states + merged_key_states[packed_key_value_indexes] = past_key_states + merged_value_states[packed_query_indexes] = packed_value_states + merged_value_states[packed_key_value_indexes] = past_value_states + key_values_lens = key_values_lens + query_lens + else: + merged_key_states = packed_key_states + merged_value_states = packed_value_states + key_values_lens = query_lens + + cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) + + packed_attn_output = flash_attn_varlen_func( + q=packed_query_states, + k=merged_key_states, + v=merged_value_states, + cu_seqlens_q=cu_seqlens_q.to(torch.int32), + cu_seqlens_k=cu_seqlens_k.to(torch.int32), + max_seqlen_q=max(query_lens).item(), + max_seqlen_k=max(key_values_lens).item(), + causal=is_causal, + ) + packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size) + packed_attn_output = self.o_proj(packed_attn_output) + + if update_past_key_values: + past_key_values.key_cache[self.layer_idx] = merged_key_states + past_key_values.value_cache[self.layer_idx] = merged_value_states + + return packed_attn_output, past_key_values + + +class PackedAttentionMoT(Qwen2Attention): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + if self.config.qk_norm: + self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + self.q_norm_moe_gen = nn.Identity() + self.k_norm_moe_gen = nn.Identity() + + self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward(self, *args, **kwargs): + if self.training: + return self.forward_train(*args, **kwargs) + else: + return self.forward_inference(*args, **kwargs) + + def forward_train( + self, + packed_sequence: torch.Tensor, + sample_lens: List[int], + attention_mask, + packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor], + packed_und_token_indexes: torch.LongTensor, + packed_gen_token_indexes: torch.LongTensor, + ): + packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim)) + packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)) + packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)) + + packed_sequence_und = packed_sequence[packed_und_token_indexes] + packed_sequence_gen = packed_sequence[packed_gen_token_indexes] + + packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und) + 
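+        # MoT routing: the und (understanding) tokens above go through the shared Qwen2
+        # q/k/v projections, while the gen (generation) tokens below use the *_moe_gen copies.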
packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen) + + packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und) + packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen) + + packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und) + packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen) + + packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) + packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + if self.config.freeze_und: + packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach() + + packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape) + packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape) + + packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes]) + if self.config.freeze_und: + packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach() + packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes]) + + packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes]) + if self.config.freeze_und: + packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach() + packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes]) + + packed_cos, packed_sin = packed_position_embeddings + packed_query_states_, packed_key_states_ = apply_rotary_pos_emb( + packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1 + ) + + if isinstance(attention_mask, List): + packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) + packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim) + packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1) + packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim) + + unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1) + unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1) + unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1) + upacked_attn_output = [] + for query_states, key_states, value_states, attention_mask_per_sample in zip( + unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask + ): + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): + attn_output = scaled_dot_product_attention( + query_states.to(torch.bfloat16).unsqueeze(0), + key_states.to(torch.bfloat16).unsqueeze(0), + value_states.to(torch.bfloat16).unsqueeze(0), + attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0), + ) + upacked_attn_output.append(attn_output.squeeze(0)) + packed_attn_output = torch.cat(upacked_attn_output, dim=1) + else: + pad_size = sum(sample_lens) - packed_query_states.shape[0] + packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size) + packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size) + packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), 
pad_size) + packed_attn_output = flex_attention( + packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim + packed_key_states_.unsqueeze(0), + packed_value_states.unsqueeze(0), + enable_gqa=True, + block_mask=attention_mask, + ) + end_index = packed_attn_output.shape[2] - pad_size + packed_attn_output = packed_attn_output[0, :, :end_index, :] + + packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim) + packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape) + packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes]) + packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes]) + + return packed_attn_output_ + + def forward_inference( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: Optional[NaiveCache] = None, + key_values_lens: Optional[torch.Tensor] = None, + packed_key_value_indexes: Optional[torch.Tensor] = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ): + if mode == 'und': + packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) + packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_query_states = self.q_norm(packed_query_states) + packed_key_states = self.k_norm(packed_key_states) + elif mode == 'gen': + packed_query_sequence = packed_query_sequence.to(torch.bfloat16) + packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim)) + packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)) + packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)) + + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + + packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence) + packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence) + + packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence) + packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence) + + packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence) + packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence) + + packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) + packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + + packed_query_states = packed_query_states.to(torch.float32) + packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) + packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes]) + + packed_key_states = packed_key_states.to(torch.float32) + packed_key_states[packed_text_indexes] = 
self.k_norm(packed_key_states[packed_text_indexes]) + packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes]) + + packed_cos, packed_sin = packed_query_position_embeddings + packed_query_states, packed_key_states = apply_rotary_pos_emb( + packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1 + ) + + packed_query_states = packed_query_states.to(torch.bfloat16) + packed_key_states = packed_key_states.to(torch.bfloat16) + packed_value_states = packed_value_states.to(torch.bfloat16) + + if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + past_key_states = past_key_values.key_cache[self.layer_idx] + past_value_states = past_key_values.value_cache[self.layer_idx] + + seqlens = sum(query_lens) + sum(key_values_lens) + merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_key_states[packed_query_indexes] = packed_key_states + merged_key_states[packed_key_value_indexes] = past_key_states + merged_value_states[packed_query_indexes] = packed_value_states + merged_value_states[packed_key_value_indexes] = past_value_states + key_values_lens = key_values_lens + query_lens + else: + merged_key_states = packed_key_states + merged_value_states = packed_value_states + key_values_lens = query_lens + + cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) + + packed_attn_output = flash_attn_varlen_func( + q=packed_query_states, + k=merged_key_states, + v=merged_value_states, + cu_seqlens_q=cu_seqlens_q.to(torch.int32), + cu_seqlens_k=cu_seqlens_k.to(torch.int32), + max_seqlen_q=max(query_lens).item(), + max_seqlen_k=max(key_values_lens).item(), + causal=is_causal, + ) + packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size) + if mode == 'und': + packed_attn_output = self.o_proj(packed_attn_output) + elif mode == 'gen': + packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes]) + packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes]) + + if update_past_key_values: + past_key_values.key_cache[self.layer_idx] = merged_key_states + past_key_values.value_cache[self.layer_idx] = merged_value_states + + return packed_attn_output, past_key_values + + +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = PackedAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, *args, **kwargs): + if self.training: + return self.forward_train(*args, **kwargs) + else: + return self.forward_inference(*args, **kwargs) + + def forward_train( + self, + packed_sequence: torch.Tensor, + sample_lens: List[int], + attention_mask, + packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + + residual = packed_sequence + packed_sequence = self.input_layernorm(packed_sequence) + + # Self Attention + packed_sequence = self.self_attn( + packed_sequence=packed_sequence, + 
sample_lens=sample_lens, + attention_mask=attention_mask, + packed_position_embeddings=packed_position_embeddings, + ) + packed_sequence = residual + packed_sequence + + # Fully Connected + residual = packed_sequence + packed_sequence = self.post_attention_layernorm(packed_sequence) + packed_sequence = self.mlp(packed_sequence) + packed_sequence = residual + packed_sequence + + return packed_sequence + + def forward_inference( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: Optional[NaiveCache] = None, + key_values_lens: Optional[torch.Tensor] = None, + packed_key_value_indexes: Optional[torch.Tensor] = None, + update_past_key_values=True, + is_causal=True, + ) -> BaseNavitOutputWithPast: + + residual = packed_query_sequence + packed_query_sequence = self.input_layernorm(packed_query_sequence) + + # Self Attention + packed_query_sequence, past_key_values = self.self_attn( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + ) + packed_query_sequence = residual + packed_query_sequence + + # Fully Connected + residual = packed_query_sequence + packed_query_sequence = self.post_attention_layernorm(packed_query_sequence) + packed_query_sequence = self.mlp(packed_query_sequence) + packed_query_sequence = residual + packed_query_sequence + + return packed_query_sequence, past_key_values + + +class Qwen2MoTDecoderLayer(nn.Module): + def __init__( + self, + config, + layer_idx: Optional[int] = None, + attn_module: Optional[Qwen2Attention] = PackedAttentionMoT, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.freeze_und = config.freeze_und + + self.self_attn = attn_module(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.mlp_moe_gen = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, *args, **kwargs): + if self.training: + return self.forward_train(*args, **kwargs) + else: + return self.forward_inference(*args, **kwargs) + + def forward_train( + self, + packed_sequence: torch.Tensor, + sample_lens: List[int], + attention_mask, + packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor], + packed_und_token_indexes: torch.LongTensor, + packed_gen_token_indexes: torch.LongTensor, + ) -> torch.Tensor: + + residual = packed_sequence + packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape) + packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes]) + packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes]) + + # Self Attention + packed_sequence_ = self.self_attn( + packed_sequence=packed_sequence_, + sample_lens=sample_lens, + attention_mask=attention_mask, + packed_position_embeddings=packed_position_embeddings, + 
packed_und_token_indexes=packed_und_token_indexes, + packed_gen_token_indexes=packed_gen_token_indexes, + ) + if self.freeze_und: + packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach() + packed_sequence = residual + packed_sequence_ + + # Fully Connected + residual = packed_sequence + packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape) + packed_sequence_[packed_und_token_indexes] = self.mlp( + self.post_attention_layernorm(packed_sequence[packed_und_token_indexes]) + ) + if self.freeze_und: + packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach() + + packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen( + self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes]) + ) + packed_sequence = residual + packed_sequence_ + + return packed_sequence + + def forward_inference( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: Optional[NaiveCache] = None, + key_values_lens: Optional[torch.Tensor] = None, + packed_key_value_indexes: Optional[torch.Tensor] = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + + residual = packed_query_sequence + if mode == "und": + packed_query_sequence = self.input_layernorm(packed_query_sequence) + elif mode == "gen": + packed_query_sequence_ = torch.zeros_like(packed_query_sequence) + packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes]) + packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes]) + packed_query_sequence = packed_query_sequence_ + + # Self Attention + packed_query_sequence, past_key_values = self.self_attn( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + mode=mode, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + packed_query_sequence = residual + packed_query_sequence + + # Fully Connected + residual = packed_query_sequence + if mode == "und": + packed_query_sequence = self.post_attention_layernorm(packed_query_sequence) + packed_query_sequence = self.mlp(packed_query_sequence) + elif mode == "gen": + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16) + packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16) + + packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) + packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence) + packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence) + packed_query_sequence = packed_query_sequence_ + + packed_query_sequence = residual + packed_query_sequence + return packed_query_sequence, 
past_key_values + + +class Qwen2MoEDecoderLayer(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = PackedAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.mlp_moe_gen = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, *args, **kwargs): + if self.training: + return self.forward_train(*args, **kwargs) + else: + return self.forward_inference(*args, **kwargs) + + def forward_train( + self, + packed_sequence: torch.Tensor, + sample_lens: List[int], + attention_mask, + packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor], + packed_und_token_indexes: torch.LongTensor, + packed_gen_token_indexes: torch.LongTensor, + ) -> torch.Tensor: + + residual = packed_sequence + packed_sequence = self.input_layernorm(packed_sequence) + + # Self Attention + packed_sequence = self.self_attn( + packed_sequence=packed_sequence, + sample_lens=sample_lens, + attention_mask=attention_mask, + packed_position_embeddings=packed_position_embeddings, + ) + packed_sequence = residual + packed_sequence + + # Fully Connected + residual = packed_sequence + packed_sequence = self.post_attention_layernorm(packed_sequence) + + packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape) + packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes]) + packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes]) + packed_sequence_new[packed_und_token_indexes] = packed_sequence_und + packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen + + packed_sequence = residual + packed_sequence_new + + return packed_sequence + + def forward_inference( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: Optional[NaiveCache] = None, + key_values_lens: Optional[torch.Tensor] = None, + packed_key_value_indexes: Optional[torch.Tensor] = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + + residual = packed_query_sequence + packed_query_sequence = self.input_layernorm(packed_query_sequence) + + # Self Attention + packed_query_sequence, past_key_values = self.self_attn( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + ) + packed_query_sequence = residual + packed_query_sequence + + # Fully Connected + residual = packed_query_sequence + packed_query_sequence = self.post_attention_layernorm(packed_query_sequence) + if mode == "und": + packed_query_sequence = self.mlp(packed_query_sequence) + elif mode == "gen": + packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) + packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes]) + packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes]) + packed_query_sequence 
= packed_query_sequence_ + packed_query_sequence = residual + packed_query_sequence + + return packed_query_sequence, past_key_values + + +Decoder_layer_dict = { + "Qwen2DecoderLayer": Qwen2DecoderLayer, + "Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer, + "Qwen2MoTDecoderLayer": partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT), +} + + +class Qwen2Model(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.use_moe = 'Mo' in config.layer_module + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + layer_module = Decoder_layer_dict[config.layer_module] + self.layers = nn.ModuleList( + [layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.use_moe: + self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, *args, **kwargs): + if self.training: + return self.forward_train(*args, **kwargs) + else: + return self.forward_inference(*args, **kwargs) + + def forward_train( + self, + packed_sequence: torch.Tensor, + sample_lens: List[int], + attention_mask, + packed_position_ids: torch.Tensor, + packed_und_token_indexes: Optional[torch.LongTensor] = None, + packed_gen_token_indexes: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + if self.config.freeze_und: + packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach() + + # create position embeddings to be shared across the decoder layers + cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0)) + cos = cos.squeeze(0) + sin = sin.squeeze(0) + packed_position_embeddings = (cos, sin) + + extra_inputs = {} + if self.use_moe: + assert packed_und_token_indexes is not None + if packed_gen_token_indexes is None: + packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0]) + extra_inputs.update( + packed_und_token_indexes=packed_und_token_indexes, + packed_gen_token_indexes=packed_gen_token_indexes, + ) + + for decoder_layer in self.layers: + packed_sequence = decoder_layer( + packed_sequence=packed_sequence, + sample_lens=sample_lens, + attention_mask=attention_mask, + packed_position_embeddings=packed_position_embeddings, + **extra_inputs + ) + + if self.use_moe: + packed_sequence_ = torch.zeros_like(packed_sequence) + packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes]) + if self.config.freeze_und: + packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach() + packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes]) + return packed_sequence_ + else: + return self.norm(packed_sequence) + + def forward_inference( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_ids: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: Optional[NaiveCache] = None, + key_values_lens: Optional[torch.Tensor] = None, + packed_key_value_indexes: Optional[torch.Tensor] = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + + # create position embeddings to 
be shared across the decoder layers + cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0)) + cos = cos.squeeze(0) + sin = sin.squeeze(0) + packed_query_position_embeddings = (cos, sin) + + extra_inputs = {} + if self.use_moe: + extra_inputs.update(mode=mode) + if mode == 'gen': + assert packed_vae_token_indexes is not None + assert packed_text_indexes is not None + extra_inputs.update( + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + + for decoder_layer in self.layers: + packed_query_sequence, past_key_values = decoder_layer( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + **extra_inputs, + ) + + if self.use_moe: + if mode == "und": + packed_query_sequence = self.norm(packed_query_sequence) + elif mode == "gen": + packed_query_sequence_ = torch.zeros_like(packed_query_sequence) + packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes]) + packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes]) + packed_query_sequence = packed_query_sequence_ + else: + packed_query_sequence = self.norm(packed_query_sequence) + + return BaseNavitOutputWithPast( + packed_query_sequence=packed_query_sequence, + past_key_values=past_key_values, + ) + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def init_moe(self): + for name, param in self.named_parameters(): + if "moe_gen" in name: + original_name = name.replace("_moe_gen", "") + param.data.copy_(self.state_dict()[original_name].data) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward(self, *args, **kwargs): + if self.training: + return self.forward_train(*args, **kwargs) + else: + return self.forward_inference(*args, **kwargs) + + def forward_train( + self, + packed_sequence: torch.Tensor, + sample_lens: List[int], + attention_mask, + packed_position_ids: torch.Tensor, + packed_und_token_indexes: Optional[torch.LongTensor] = None, + packed_gen_token_indexes: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + outputs = self.model( + packed_sequence=packed_sequence, + sample_lens=sample_lens, + packed_position_ids=packed_position_ids, + attention_mask=attention_mask, + packed_und_token_indexes=packed_und_token_indexes, + packed_gen_token_indexes=packed_gen_token_indexes, + ) + return outputs + + def forward_inference( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_ids: torch.Tensor, + 
packed_query_indexes: torch.Tensor, + past_key_values: Optional[NaiveCache] = None, + key_values_lens: Optional[torch.Tensor] = None, + packed_key_value_indexes: Optional[torch.Tensor] = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + + outputs = self.model( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_ids=packed_query_position_ids, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + mode=mode, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + + return outputs + + +class SiglipVisionConfig(_SiglipVisionConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
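+        rope (`bool`, *optional*, defaults to `True`):
+            Whether to use 2D rotary position embeddings (`RotaryEmbedding2D`) for the patch tokens
+            instead of the learned absolute `position_embedding`.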
+ + Example: + + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + rope=True, + **kwargs, + ): + super().__init__( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_channels=num_channels, + image_size=image_size, + patch_size=patch_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_dropout=attention_dropout, + **kwargs) + + self.rope = rope + + +class RotaryEmbedding2D(torch.nn.Module): + def __init__(self, dim, max_h, max_w, base=10000): + super().__init__() + freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim + inv_freq = 1.0 / (base ** freq) + + grid_h = torch.arange(0, max_h) + grid_h = grid_h.to(inv_freq.dtype) + grid_h = grid_h[:, None].repeat(1, max_w) + + grid_w = torch.arange(0, max_w) + grid_w = grid_w.to(inv_freq.dtype) + grid_w = grid_w[None, :].repeat(max_h, 1) + + cos_h, sin_h = self._forward_one_side(grid_h, inv_freq) + cos_w, sin_w = self._forward_one_side(grid_w, inv_freq) + + self.register_buffer("cos_h", cos_h) + self.register_buffer("sin_h", sin_h) + self.register_buffer("cos_w", cos_w) + self.register_buffer("sin_w", sin_w) + + def _forward_one_side(self, grid, inv_freq): + freqs = grid[..., None] * inv_freq[None, None, :] + emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1) + return emb.cos(), emb.sin() + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
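+
+    Example (illustrative only: shapes follow the packed layout used in this file; the random
+    values and head counts are placeholders, real `cos`/`sin` come from a rotary embedding module):
+
+    ```python
+    >>> q = torch.randn(16, 28, 128)   # [packed_len, num_heads, head_dim]
+    >>> k = torch.randn(16, 4, 128)    # [packed_len, num_key_value_heads, head_dim]
+    >>> cos = torch.randn(16, 128)     # [packed_len, head_dim]
+    >>> sin = torch.randn(16, 128)
+    >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
+    >>> q_embed.shape, k_embed.shape
+    (torch.Size([16, 28, 128]), torch.Size([16, 4, 128]))
+    ```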
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + if not config.rope: + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def convert_conv2d_to_linear(self, config, meta=False): + if meta: + linear_patch_embedding = nn.Linear( + config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta' + ) + else: + linear_patch_embedding = nn.Linear( + config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True + ) + W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape( + self.embed_dim, config.num_channels * self.patch_size ** 2 + ) + linear_patch_embedding.weight.data = W + linear_patch_embedding.bias.data = self.patch_embedding.bias.data + del self.patch_embedding + self.patch_embedding = linear_patch_embedding + + def forward( + self, + packed_pixel_values: torch.FloatTensor, + packed_flattened_position_ids: torch.LongTensor + ) -> torch.Tensor: + + patch_embeds = self.patch_embedding(packed_pixel_values) + if not self.config.rope: + embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids) + else: + embeddings = patch_embeds + return embeddings + + +class SiglipFlashAttention2(SiglipAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.IntTensor, + max_seqlen: int, + cos_h: torch.Tensor = None, + sin_h: torch.Tensor = None, + cos_w: torch.Tensor = None, + sin_w: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + + total_q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(total_q_len, self.num_heads, self.head_dim) + key_states = key_states.view(total_q_len, self.num_heads, self.head_dim) + value_states = value_states.view(total_q_len, self.num_heads, self.head_dim) + + if self.config.rope: + qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:] + kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:] + qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h) + qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w) + query_states = torch.cat([qh, qw], dim=-1) + key_states = torch.cat([kh, kw], dim=-1) + + attn_output = flash_attn_varlen_func( + query_states.to(torch.bfloat16), + key_states.to(torch.bfloat16), + value_states.to(torch.bfloat16), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=False, + ) + + attn_output = self.out_proj(attn_output.reshape(total_q_len, -1)) + return attn_output + + +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = 
config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SiglipFlashAttention2(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.IntTensor, + max_seqlen: int, + cos_h: torch.Tensor = None, + sin_h: torch.Tensor = None, + cos_w: torch.Tensor = None, + sin_w: torch.Tensor = None + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + cos_h=cos_h, + sin_h=sin_h, + cos_w=cos_w, + sin_w=sin_w + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SiglipEncoder(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + cu_seqlens: torch.IntTensor, + max_seqlen: int, + cos_h: torch.Tensor = None, + sin_h: torch.Tensor = None, + cos_w: torch.Tensor = None, + sin_w: torch.Tensor = None, + ) -> torch.Tensor: + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen, + cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w) + + return hidden_states + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + if config.rope: + max_size = config.image_size // config.patch_size + dim_head = config.hidden_size // config.num_attention_heads + self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size) + + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + packed_pixel_values: torch.Tensor, + packed_flattened_position_ids: torch.LongTensor, + cu_seqlens: torch.IntTensor, + max_seqlen: int, + ) -> torch.Tensor: + hidden_states = self.embeddings( + packed_pixel_values=packed_pixel_values, + packed_flattened_position_ids=packed_flattened_position_ids + ) + + extra_inputs = {} + if self.config.rope: + extra_inputs.update( + cos_h = self.rope.cos_h[packed_flattened_position_ids], + sin_h = self.rope.sin_h[packed_flattened_position_ids], + cos_w = self.rope.cos_w[packed_flattened_position_ids], + sin_w = self.rope.sin_w[packed_flattened_position_ids] + ) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + **extra_inputs + ) 
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+        return last_hidden_state
+
+
+class SiglipVisionModel(SiglipPreTrainedModel):
+    config_class = SiglipVisionConfig
+    main_input_name = "packed_pixel_values"
+
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__(config)
+
+        self.vision_model = SiglipVisionTransformer(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_model.embeddings.patch_embedding
+
+    def forward(
+        self,
+        packed_pixel_values: torch.Tensor,
+        packed_flattened_position_ids: torch.LongTensor,
+        cu_seqlens: torch.IntTensor,
+        max_seqlen: int,
+    ) -> torch.Tensor:
+
+        return self.vision_model(
+            packed_pixel_values=packed_pixel_values,
+            packed_flattened_position_ids=packed_flattened_position_ids,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+
+
+class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
+    """Resize the input image so that its longest side and shortest side are within a specified range,
+    ensuring that both sides are divisible by a specified stride.
+
+    Args:
+        max_size (int): Maximum size for the longest edge of the image.
+        min_size (int): Minimum size for the shortest edge of the image.
+        stride (int): Value by which the height and width of the image must be divisible.
+        max_pixels (int): Maximum number of pixels for the full image.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BICUBIC``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
+            ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
+            The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
+        antialias (bool, optional): Whether to apply antialiasing (default is True).
+    """
+
+    def __init__(
+        self,
+        max_size: int,
+        min_size: int,
+        stride: int,
+        max_pixels: int,
+        interpolation=InterpolationMode.BICUBIC,
+        antialias=True
+    ):
+        super().__init__()
+        self.max_size = max_size
+        self.min_size = min_size
+        self.stride = stride
+        self.max_pixels = max_pixels
+        self.interpolation = interpolation
+        self.antialias = antialias
+
+    def _make_divisible(self, value, stride):
+        """Ensure the value is divisible by the stride."""
+        return max(stride, int(round(value / stride) * stride))
+
+    def _apply_scale(self, width, height, scale):
+        new_width = round(width * scale)
+        new_height = round(height * scale)
+        new_width = self._make_divisible(new_width, self.stride)
+        new_height = self._make_divisible(new_height, self.stride)
+        return new_width, new_height
+
+    def forward(self, img, img_num=1):
+        """
+        Args:
+            img (PIL Image): Image to be resized.
+            img_num (int): Number of images sharing the pixel budget; ``max_pixels`` is divided by this value.
+        Returns:
+            PIL Image or Tensor: Rescaled image with divisible dimensions.
+ """ + if isinstance(img, torch.Tensor): + height, width = img.shape[-2:] + else: + width, height = img.size + + scale = min(self.max_size / max(width, height), 1.0) + scale = max(scale, self.min_size / min(width, height)) + new_width, new_height = self._apply_scale(width, height, scale) + + # Ensure the number of pixels does not exceed max_pixels + if new_width * new_height > self.max_pixels / img_num: + scale = self.max_pixels / img_num / (new_width * new_height) + new_width, new_height = self._apply_scale(new_width, new_height, scale) + + # Ensure longest edge does not exceed max_size + if max(new_width, new_height) > self.max_size: + scale = self.max_size / max(new_width, new_height) + new_width, new_height = self._apply_scale(new_width, new_height, scale) + + return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias) + + +class ImageTransform: + def __init__( + self, + max_image_size, + min_image_size, + image_stride, + max_pixels=14*14*9*1024, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5] + ): + self.stride = image_stride + + self.resize_transform = MaxLongEdgeMinShortEdgeResize( + max_size=max_image_size, + min_size=min_image_size, + stride=image_stride, + max_pixels=max_pixels, + ) + self.to_tensor_transform = transforms.ToTensor() + self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True) + + def __call__(self, img, img_num=1): + img = self.resize_transform(img, img_num=img_num) + img = self.to_tensor_transform(img) + img = self.normalize_transform(img) + return img + + +def decolorization(image): + gray_image = image.convert('L') + return Image.merge(image.mode, [gray_image] * 3) if image.mode in ('RGB', 'L') else gray_image + + +def downscale(image, scale_factor): + new_width = int(round(image.width * scale_factor)) + new_height = int(round(image.height * scale_factor)) + new_width = max(1, new_width) + new_height = max(1, new_height) + return image.resize((new_width, new_height), resample=Image.BICUBIC) + + +def crop(image, crop_factors): + target_h, target_w = crop_factors + img_w, img_h = image.size + + if target_h > img_h or target_w > img_w: + raise ValueError("Crop size exceeds image dimensions") + + x = random.randint(0, img_w - target_w) + y = random.randint(0, img_h - target_h) + + return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]] + + +def motion_blur_opencv(image, kernel_size=15, angle=0): + # 线性核 + kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32) + kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32) + + # 旋转核 + center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5) + M = cv2.getRotationMatrix2D(center, angle, 1) + rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size)) + + # 归一化核 + rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1 + + img = np.array(image) + if img.ndim == 2: + blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT) + else: + # 对于彩色图像,各通道独立卷积 + blurred = np.zeros_like(img) + for c in range(img.shape[2]): + blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT) + + return Image.fromarray(blurred.astype(np.uint8)) + + +def shuffle_patch(image, num_splits, gap_size=2): + """将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙""" + h_splits, w_splits = num_splits + img_w, img_h = image.size + + base_patch_h = img_h // h_splits + patch_heights = [base_patch_h] * (h_splits - 1) + patch_heights.append(img_h - 
sum(patch_heights)) + + base_patch_w = img_w // w_splits + patch_widths = [base_patch_w] * (w_splits - 1) + patch_widths.append(img_w - sum(patch_widths)) + + patches = [] + current_y = 0 + for i in range(h_splits): + current_x = 0 + patch_h = patch_heights[i] + for j in range(w_splits): + patch_w = patch_widths[j] + patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h)) + patches.append(patch) + current_x += patch_w + current_y += patch_h + + random.shuffle(patches) + + total_width = sum(patch_widths) + (w_splits - 1) * gap_size + total_height = sum(patch_heights) + (h_splits - 1) * gap_size + new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255)) + + current_y = 0 # 当前行的起始 Y 坐标 + patch_idx = 0 # 当前处理的块索引 + for i in range(h_splits): + current_x = 0 # 当前列的起始 X 坐标 + patch_h = patch_heights[i] # 当前行块的高度 + for j in range(w_splits): + # 取出打乱后的块 + patch = patches[patch_idx] + patch_w = patch_widths[j] # 当前列块的宽度 + # 粘贴块(左上角坐标为 (current_x, current_y)) + new_image.paste(patch, (current_x, current_y)) + # 更新 X 坐标(下一个块的起始位置 = 当前块宽度 + 间隙) + current_x += patch_w + gap_size + patch_idx += 1 + # 更新 Y 坐标(下一行的起始位置 = 当前行高度 + 间隙) + current_y += patch_h + gap_size + + return new_image + + +def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)): + """ + 图像分割后随机空白部分patch,用于inpainting任务 + + 参数: + image: PIL.Image 输入图像(RGB模式) + h_splits: int 行分割数(垂直方向分割块数) + w_splits: int 列分割数(水平方向分割块数) + blank_ratio: float 空白patch的比例(0~1) + blank_color: tuple 空��区域的颜色(RGB,如白色(255,255,255)) + + 返回: + PIL.Image 处理后拼接的图像 + """ + h_splits, w_splits = num_splits + img_w, img_h = image.size + + base_patch_h = img_h // h_splits + patch_heights = [base_patch_h] * (h_splits - 1) + patch_heights.append(img_h - sum(patch_heights)) + + base_patch_w = img_w // w_splits + patch_widths = [base_patch_w] * (w_splits - 1) + patch_widths.append(img_w - sum(patch_widths)) + + patches = [] + current_y = 0 + for i in range(h_splits): + current_x = 0 + patch_h = patch_heights[i] + for j in range(w_splits): + patch_w = patch_widths[j] + patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h)) + patches.append(patch) + current_x += patch_w + current_y += patch_h + + total_patches = h_splits * w_splits + num_blank = int(total_patches * blank_ratio) + num_blank = max(0, min(num_blank, total_patches)) + blank_indices = random.sample(range(total_patches), num_blank) + + processed_patches = [] + for idx, patch in enumerate(patches): + if idx in blank_indices: + blank_patch = Image.new("RGB", patch.size, color=blank_color) + processed_patches.append(blank_patch) + else: + processed_patches.append(patch) + + # 创建结果图像(尺寸与原图一致) + result_image = Image.new("RGB", (img_w, img_h)) + current_y = 0 + patch_idx = 0 + for i in range(h_splits): + current_x = 0 + patch_h = patch_heights[i] + for j in range(w_splits): + # 取出处理后的patch + patch = processed_patches[patch_idx] + patch_w = patch_widths[j] + # 粘贴到原位置 + result_image.paste(patch, (current_x, current_y)) + current_x += patch_w + patch_idx += 1 + current_y += patch_h + + return result_image + + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + downsample: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels 
+ + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if 
i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = 
chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(ModelMixin, ConfigMixin): + def __init__(self, params: AutoEncoderParams | None = None, **kwargs): + if params is None: + params = AutoEncoderParams(**kwargs) + super().__init__() + self.register_to_config(**asdict(params)) + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + @classmethod + def from_config(cls, config, **unused): + """ + Diffusers passes us `config` as a *dict* here. + Rebuild the AutoEncoderParams dataclass from that dict and + delegate to the normal constructor. + """ + # keep only keys that exist in AutoEncoderParams + allowed = {f.name for f in fields(AutoEncoderParams)} + params_dict = {k: v for k, v in config.items() if k in allowed} + + params = AutoEncoderParams(**params_dict) + return cls(params) + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + + +def load_ae(local_path: str) -> AutoEncoder: + ae_params = AutoEncoderParams( + resolution=256, + in_channels=3, + downsample=8, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) + + # Loading the autoencoder + ae = AutoEncoder(ae_params) + + if local_path is not None: + sd = load_sft(local_path) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae, ae_params + + +VLM_THINK_SYSTEM_PROMPT = '''You should first think about the reasoning process in the mind and then provide the user with the answer. +The reasoning process is enclosed within tags, i.e. reasoning process here answer here''' + +GEN_THINK_SYSTEM_PROMPT = '''You should first think about the planning process in the mind and then generate the image. +The planning process is enclosed within tags, i.e. 
planning process here image here''' + + +class InterleaveInferencer: + def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids): + self.model = model + self.vae_model = vae_model + self.tokenizer = tokenizer + self.vae_transform = vae_transform + self.vit_transform = vit_transform + self.new_token_ids = new_token_ids + + def _to_device(self, d, device): + """Recursively move every tensor in *d* to *device*.""" + for k, v in d.items(): + if torch.is_tensor(v): + d[k] = v.to(device) + return d + def to(self, device): + self.model = self.model.to(device) + self.vae_model = self.vae_model.to(device) + return self + + + def init_gen_context(self): + gen_context = { + 'kv_lens': [0], + 'ropes': [0], + 'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers), + } + return gen_context + + @torch.no_grad() + def update_context_text(self, text, gen_context): + # used for interleave data, currently only support 1 data inference, + + past_key_values = gen_context['past_key_values'] + kv_lens = gen_context['kv_lens'] + ropes = gen_context['ropes'] + generation_input, kv_lens, ropes = self.model.prepare_prompts( + curr_kvlens=kv_lens, + curr_rope=ropes, + prompts=[text], + tokenizer=self.tokenizer, + new_token_ids=self.new_token_ids, + ) + generation_input = self._to_device(generation_input, + next(self.model.parameters()).device) + past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input) + gen_context['kv_lens'] = kv_lens + gen_context['ropes'] = ropes + gen_context['past_key_values'] = past_key_values + + return gen_context + + @torch.no_grad() + def update_context_image(self, image, gen_context, vae=True, vit=True): + # used for interleave data, currently only support 1 data inference, + + assert vae or vit + past_key_values = gen_context['past_key_values'] + kv_lens = gen_context['kv_lens'] + ropes = gen_context['ropes'] + device = next(self.model.parameters()).device + if vae: + ## update vae + generation_input, kv_lens, ropes = self.model.prepare_vae_images( + curr_kvlens=kv_lens, + curr_rope=ropes, + images=[image], + transforms=self.vae_transform, + new_token_ids=self.new_token_ids, + ) + generation_input = self._to_device(generation_input, device) + past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input) + + if vit: + ## update vit + generation_input, kv_lens, ropes = self.model.prepare_vit_images( + curr_kvlens=kv_lens, + curr_rope=ropes, + images=[image], + transforms=self.vit_transform, + new_token_ids=self.new_token_ids, + ) + generation_input = self._to_device(generation_input, device) + past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input) + + gen_context['kv_lens'] = kv_lens + gen_context['ropes'] = ropes + gen_context['past_key_values'] = past_key_values + + return gen_context + + @torch.no_grad() + def gen_image( + self, + image_shape, + gen_context, + cfg_text_scale=4.0, + cfg_img_scale=1.5, + + cfg_text_precontext=None, + cfg_img_precontext=None, + cfg_interval=(0.4, 1.0), + cfg_renorm_min=0.0, + cfg_renorm_type="global", + + num_timesteps=50, + timestep_shift=3.0 + ): + # print(cfg_renorm_type) + device = next(self.model.parameters()).device + past_key_values = gen_context['past_key_values'] + kv_lens = gen_context['kv_lens'] + ropes = gen_context['ropes'] + generation_input = self.model.prepare_vae_latent( + curr_kvlens=kv_lens, + curr_rope=ropes, + image_sizes=[image_shape], + 
new_token_ids=self.new_token_ids, + ) + generation_input = self._to_device(generation_input, device) + + # text cfg + cfg_text_past_key_values = cfg_text_precontext['past_key_values'] + kv_lens_cfg = cfg_text_precontext['kv_lens'] + ropes_cfg = cfg_text_precontext['ropes'] + generation_input_cfg_text = self.model.prepare_vae_latent_cfg( + curr_kvlens=kv_lens_cfg, + curr_rope=ropes_cfg, + image_sizes=[image_shape], + ) + generation_input_cfg_text = self._to_device(generation_input_cfg_text, device) + + # img cfg + cfg_img_past_key_values = cfg_img_precontext['past_key_values'] + kv_lens_cfg = cfg_img_precontext['kv_lens'] + ropes_cfg = cfg_img_precontext['ropes'] + generation_input_cfg_img = self.model.prepare_vae_latent_cfg( + curr_kvlens=kv_lens_cfg, + curr_rope=ropes_cfg, + image_sizes=[image_shape], + ) + generation_input_cfg_img = self._to_device(generation_input_cfg_img, device) + + unpacked_latent = self.model.generate_image( + past_key_values=past_key_values, + cfg_text_past_key_values=cfg_text_past_key_values, + cfg_img_past_key_values=cfg_img_past_key_values, + num_timesteps=num_timesteps, + cfg_text_scale=cfg_text_scale, + cfg_img_scale=cfg_img_scale, + cfg_interval=cfg_interval, + cfg_renorm_min=cfg_renorm_min, + cfg_renorm_type=cfg_renorm_type, + timestep_shift=timestep_shift, + **generation_input, + cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'], + cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'], + cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'], + cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'], + cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'], + cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'], + cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'], + cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'], + ) + + image = self.decode_image(unpacked_latent[0], image_shape) + return image + + + def decode_image(self, latent, image_shape): + H, W = image_shape + h, w = H // self.model.latent_downsample, W // self.model.latent_downsample + + latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel) + latent = torch.einsum("nhwpqc->nchpwq", latent) + latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size) + image = self.vae_model.decode(latent) + image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255 + image = Image.fromarray((image).to(torch.uint8).cpu().numpy()) + + return image + + @torch.no_grad() + def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0): + gen_context = deepcopy(gen_context) + past_key_values = gen_context['past_key_values'] + kv_lens = gen_context['kv_lens'] + ropes = gen_context['ropes'] + + generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids) + unpacked_latent = self.model.generate_text( + past_key_values=past_key_values, + max_length=max_length, + do_sample=do_sample, + temperature=temperature, + end_token_id=self.new_token_ids['eos_token_id'], + **generation_input, + ) + output = self.tokenizer.decode(unpacked_latent[:,0]) + output = output.split('<|im_end|>')[0].split('<|im_start|>')[1] + return output + + @torch.no_grad() + def interleave_inference( + self, + input_lists: 
List[Union[str, Image.Image]], + think=False, + understanding_output=False, + + max_think_token_n=1000, + do_sample=False, + text_temperature=0.3, + cfg_text_scale=3.0, + cfg_img_scale=1.5, + cfg_interval=[0.4, 1.0], + timestep_shift=3.0, + num_timesteps=50, + cfg_renorm_min=0.0, + cfg_renorm_type="global", + image_shapes=(1024, 1024), + ) -> List[Union[str, Image.Image]]: + + output_list = [] + gen_context = self.init_gen_context() + cfg_text_context = deepcopy(gen_context) + cfg_img_context = deepcopy(gen_context) + + with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16): + if think: + if understanding_output: + system_prompt = VLM_THINK_SYSTEM_PROMPT + else: + system_prompt = GEN_THINK_SYSTEM_PROMPT + gen_context = self.update_context_text(system_prompt, gen_context) + cfg_img_context = self.update_context_text(system_prompt, cfg_img_context) + + for input_term in input_lists: + if isinstance(input_term, str): + cfg_text_context = deepcopy(gen_context) + gen_context = self.update_context_text(input_term, gen_context) + cfg_img_context = self.update_context_text(input_term, cfg_img_context) + + elif isinstance(input_term, Image.Image): + input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term)) + gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output) + + image_shapes = input_term.size[::-1] + cfg_text_context = deepcopy(gen_context) + + else: + raise ValueError(f"Unsupported input type: {type(input_term)}") + + if understanding_output: + gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n) + output_list.append(gen_text) + + else: + if think: + gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n) + gen_context = self.update_context_text(gen_text, gen_context) + output_list.append(gen_text) + + img = self.gen_image( + image_shapes, + gen_context, + cfg_text_precontext=cfg_text_context, + cfg_img_precontext=cfg_img_context, + + cfg_text_scale=cfg_text_scale, + cfg_img_scale=cfg_img_scale, + cfg_interval=cfg_interval, + timestep_shift=timestep_shift, + num_timesteps=num_timesteps, + cfg_renorm_min=cfg_renorm_min, + cfg_renorm_type=cfg_renorm_type, + ) + + output_list.append(img) + + return output_list + + def __call__( + self, + image: Optional[Image.Image] = None, + text: Optional[str] = None, + **kargs + ) -> Dict[str, Any]: + output_dict = {'image': None, 'text': None} + + if image is None and text is None: + print('Please provide at least one input: either an image or text.') + return output_dict + + input_list = [] + if image is not None: + input_list.append(image) + if text is not None: + input_list.append(text) + + output_list = self.interleave_inference(input_list, **kargs) + + for i in output_list: + if isinstance(i, Image.Image): + output_dict['image'] = i + elif isinstance(i, str): + output_dict['text'] = i + return output_dict + +# class BagelPipeline(DiffusionPipeline): +# """ +# A “naive” Bagel wrapper that replicates your notebook exactly. 
+# """ + +# model_cpu_offload_seq = "bagel_model" + +# def __init__( +# self, +# torch_dtype: torch.dtype = torch.bfloat16, +# ): +# super().__init__() +# self._dtype = torch_dtype +# self._built = False +# self._inferencer = None +# self.new_token_ids: List[int] = [] +# # Hard‐code default weights path; overridden by from_pretrained +# self.weights_root: Optional[str] = None +# self.register_to_config(weights_root=self.weights_root, torch_dtype=torch_dtype) +# repo_id = "ByteDance-Seed/BAGEL-7B-MoT" +# model_path = snapshot_download(repo_id=repo_id) +# print("loaded from ", model_path) +# # LLM config preparing +# llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json")) +# llm_config.qk_norm = True +# llm_config.tie_word_embeddings = False +# llm_config.layer_module = "Qwen2MoTDecoderLayer" + +# # ViT config preparing +# vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json")) +# vit_config.rope = False +# vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1 + +# # VAE loading +# vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors")) + +# # Bagel config preparing +# config = BagelConfig( +# visual_gen=True, +# visual_und=True, +# llm_config=llm_config, +# vit_config=vit_config, +# vae_config=vae_config, +# vit_max_num_patch_per_side=70, +# connector_act='gelu_pytorch_tanh', +# latent_patch_size=2, +# max_latent_size=64, +# ) + +# with init_empty_weights(): +# language_model = Qwen2ForCausalLM(llm_config) +# vit_model = SiglipVisionModel(vit_config) +# model = Bagel(language_model, vit_model, config) +# model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True) + +# # Tokenizer Preparing +# tokenizer = Qwen2Tokenizer.from_pretrained(model_path) +# tokenizer, new_token_ids, _ = add_special_tokens(tokenizer) + +# # Image Transform Preparing +# vae_transform = ImageTransform(1024, 512, 16) +# vit_transform = ImageTransform(980, 224, 14) + +# # set cuda device to 4 + +# max_mem_per_gpu = "40GiB" # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU. 
+ +# device_map = infer_auto_device_map( +# model, +# max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())}, +# no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"], +# ) +# print(device_map) + +# same_device_modules = [ +# 'language_model.model.embed_tokens', +# 'time_embedder', +# 'latent_pos_embed', +# 'vae2llm', +# 'llm2vae', +# 'connector', +# 'vit_pos_embed' +# ] + +# if torch.cuda.device_count() == 1: +# first_device = device_map.get(same_device_modules[0], "cuda:0") +# for k in same_device_modules: +# if k in device_map: +# device_map[k] = first_device +# else: +# device_map[k] = "cuda:0" +# else: +# first_device = device_map.get(same_device_modules[0]) +# for k in same_device_modules: +# if k in device_map: +# device_map[k] = first_device + +# # Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8 +# model = load_checkpoint_and_dispatch( +# model, +# checkpoint=os.path.join(model_path, "ema.safetensors"), +# device_map=device_map, +# offload_buffers=True, +# dtype=torch.bfloat16, +# force_hooks=True, +# offload_folder="/tmp/offload" +# ) + +# model = model.eval() +# print('Model loaded') + +# self._inferencer = InterleaveInferencer( +# model=model, +# vae_model=vae_model, +# tokenizer=tokenizer, +# vae_transform=vae_transform, +# vit_transform=vit_transform, +# new_token_ids=new_token_ids +# ) + +# seed = 42 +# random.seed(seed) +# np.random.seed(seed) +# torch.manual_seed(seed) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed(seed) +# torch.cuda.manual_seed_all(seed) +# torch.backends.cudnn.deterministic = True +# torch.backends.cudnn.benchmark = False + + +# @torch.no_grad() +# def __call__( +# self, +# prompt: str, +# think=False, +# cfg_text_scale: float = 4.0, +# cfg_img_scale: float = 1.0, +# cfg_interval=(0.4, 1.0), +# timestep_shift: float = 3.0, +# num_timesteps: int = 50, +# cfg_renorm_min: float = 0.0, +# cfg_renorm_type: str = "global", +# seed: Optional[int] = None, +# output_type: str = "pil", +# return_dict: bool = True, +# **unused, +# ): + +# if seed is not None: +# torch.manual_seed(seed) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed_all(seed) + +# inference_kwargs = dict( +# text=prompt, +# think=think, +# cfg_text_scale=cfg_text_scale, +# cfg_img_scale=cfg_img_scale, +# cfg_interval=list(cfg_interval), +# timestep_shift=timestep_shift, +# num_timesteps=num_timesteps, +# cfg_renorm_min=cfg_renorm_min, +# cfg_renorm_type=cfg_renorm_type, +# ) +# result = self._inferencer(**inference_kwargs) +# image = result["image"] if isinstance(result, dict) else result +# if return_dict: +# return {"images": [image]} +# return [image] + +class BagelPipeline(DiffusionPipeline): + model_cpu_offload_seq = "bagel_model" + + def __init__(self, bagel_model, vae, tokenizer): + super().__init__() + self.register_modules( + bagel_model = bagel_model, + vae = vae, + tokenizer = tokenizer, + ) + tokenizer, new_token_ids, _ = add_special_tokens(tokenizer) + self._inferencer = InterleaveInferencer( + model = bagel_model, + vae_model = vae, + tokenizer = tokenizer, + vae_transform= ImageTransform(1024, 512, 16), + vit_transform= ImageTransform(980, 224, 14), + new_token_ids= new_token_ids, + ) + + def __call__(self, prompt: str, **infer_kwargs): + result = self._inferencer(text=prompt, **infer_kwargs) + img = result["image"] if isinstance(result, dict) else result + return {"images": [img]} + + + def to(self, device): + super().to(device) # moves registered modules + if hasattr(self, "_inferencer"): + self._inferencer.to(device) 
+ return self +
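+
+
+# ---------------------------------------------------------------------------
+# Usage sketch (illustrative only, not part of the original pipeline code).
+# It mirrors the single-GPU path of the commented-out wrapper above; the
+# checkpoint layout (llm_config.json, vit_config.json, ae.safetensors,
+# ema.safetensors), the device_map={"": "cuda:0"}, the example prompt, and
+# the output filename are all assumptions made for this sketch.
+if __name__ == "__main__":
+    model_path = snapshot_download(repo_id="ByteDance-Seed/BAGEL-7B-MoT")
+
+    # LLM and ViT configs, tweaked the same way as in the commented-out wrapper.
+    llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
+    llm_config.qk_norm = True
+    llm_config.tie_word_embeddings = False
+    llm_config.layer_module = "Qwen2MoTDecoderLayer"
+
+    vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
+    vit_config.rope = False
+    vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
+
+    # Flux-style VAE used by the generation branch.
+    vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
+    vae_model = vae_model.to("cuda")
+
+    config = BagelConfig(
+        visual_gen=True,
+        visual_und=True,
+        llm_config=llm_config,
+        vit_config=vit_config,
+        vae_config=vae_config,
+        vit_max_num_patch_per_side=70,
+        connector_act='gelu_pytorch_tanh',
+        latent_patch_size=2,
+        max_latent_size=64,
+    )
+
+    # Build the model skeleton on the meta device, then load the EMA weights onto one GPU.
+    with init_empty_weights():
+        language_model = Qwen2ForCausalLM(llm_config)
+        vit_model = SiglipVisionModel(vit_config)
+        model = Bagel(language_model, vit_model, config)
+        model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
+
+    model = load_checkpoint_and_dispatch(
+        model,
+        checkpoint=os.path.join(model_path, "ema.safetensors"),
+        device_map={"": "cuda:0"},  # single-GPU assumption; see the multi-GPU recipe above
+        offload_buffers=True,
+        dtype=torch.bfloat16,
+        force_hooks=True,
+    ).eval()
+
+    tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
+
+    # The pipeline adds the special tokens and builds the InterleaveInferencer itself.
+    pipe = BagelPipeline(bagel_model=model, vae=vae_model, tokenizer=tokenizer)
+    out = pipe(
+        "A photo of a corgi wearing sunglasses on a beach",
+        cfg_text_scale=4.0,
+        cfg_img_scale=1.0,
+        num_timesteps=50,
+    )
+    out["images"][0].save("bagel_t2i_sample.png")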