|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch BailingMoE model.""" |
|
|
|
import math |
|
import warnings |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.modeling_attn_mask_utils import ( |
|
AttentionMaskConverter, |
|
_prepare_4d_attention_mask, |
|
_prepare_4d_causal_attention_mask, |
|
_prepare_4d_causal_attention_mask_for_sdpa, |
|
) |
|
from transformers.modeling_outputs import MoeModelOutputWithPast |
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
from transformers.utils.import_utils import is_torch_fx_available |
|
from .configuration_bailing_moe_v2 import BailingMoeV2Config |
|
from transformers.generation.utils import GenerationMixin |
|
from dataclasses import dataclass |
|
from transformers.utils import ModelOutput |
|
|
|
|
|
if is_flash_attn_2_available(): |
|
from flash_attn import flash_attn_func, flash_attn_varlen_func |
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
|
|
|
|
|
|
|
|
|
# When torch.fx is available, rebind the module-level `_prepare_4d_causal_attention_mask`
# to a `torch.fx.wrap`-registered version, so FX symbolic tracing records a call to it
# (a leaf function) instead of tracing through its Python control flow. Later uses of the
# name in this file get the wrapped function.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        # On torch < 1.13 the `torch.fx` submodule is imported explicitly here
        # (presumably because `import torch` did not expose it there).
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
|
|
|
|
# Module logger; the attention classes below emit `warning_once` messages through it.
logger = logging.get_logger(__name__)

# Name of the configuration class for generated docs
# (presumably consumed by `replace_return_docstrings` further down the file — confirm).
_CONFIG_FOR_DOC = "BailingMoeV2Config"
|
|
|
|
|
def roll_tensor(tensor, shifts=-1, dims=-1, fill_value=0):
    """Shift ``tensor`` along one dimension, filling the vacated positions.

    ``torch.roll`` is circular: elements pushed off one end re-enter at the other.
    This helper overwrites those wrapped-in positions with ``fill_value``, turning
    the roll into a plain shift. For ``shifts < 0`` the last ``abs(shifts)`` slots
    along ``dims`` are filled. For ``shifts > 0`` the first ``shifts`` slots are
    filled. ``shifts == 0`` leaves the data untouched.

    Args:
        tensor: Input tensor. It is not modified; ``torch.roll`` returns a new tensor.
        shifts: Signed number of positions to shift by.
        dims: The single dimension to shift along (negative indices allowed).
        fill_value: Value written into the vacated positions.

    Returns:
        Tuple ``(rolled_tensor, rolled_tensor.sum())``.
    """
    rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
    if shifts != 0:
        dim_size = rolled_tensor.size(dims)
        # A shift larger than the dimension vacates every slot.
        num_vacated = min(abs(shifts), dim_size)
        # Negative shift: the wrapped elements land at the end. Positive: at the start.
        start = dim_size - num_vacated if shifts < 0 else 0
        # `narrow` returns a view, so `fill_` writes into `rolled_tensor` in place.
        rolled_tensor.narrow(dims, start, num_vacated).fill_(fill_value)
    return rolled_tensor, rolled_tensor.sum()
|
|
|
|
|
@dataclass
class MoEV2CausalLMOutputWithPast(ModelOutput):
    """
    Outputs of the BailingMoeV2 causal language model. Covers the language-modeling outputs, the
    Mixture-of-Experts router terms, and the optional MTP (presumably multi-token-prediction) outputs.

    NOTE: `ModelOutput` also supports tuple-style indexing in field order, so do not reorder the fields below.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            z_loss for the sparse modules.
        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            aux_loss for the sparse modules.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Router logits of each MoE layer, useful to compute the auxiliary loss and the z_loss for the sparse
            modules.
        mtp_loss (`torch.FloatTensor`, *optional*):
            Loss of the MTP branch. NOTE(review): the code path that fills this is not visible here — confirm
            when it is populated.
        mtp_logits (`tuple(torch.FloatTensor)`, *optional*):
            Logits produced by the MTP branch, presumably one tensor per MTP layer — confirm against the model's
            forward.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    z_loss: Optional[torch.FloatTensor] = None
    aux_loss: Optional[torch.FloatTensor] = None
    router_logits: Optional[tuple[torch.FloatTensor]] = None
    mtp_loss: Optional[torch.FloatTensor] = None
    mtp_logits: Optional[tuple[torch.FloatTensor, ...]] = None
|
|
|
|
|
class MoeV2ModelOutputWithPast(MoeModelOutputWithPast):
    """`MoeModelOutputWithPast` carrying an extra `mtp_hidden_states` attribute.

    NOTE(review): this subclass is not decorated with `@dataclass` and does not declare
    `mtp_hidden_states` as a field. The value is assigned as a plain attribute after the
    parent initializer runs. It presumably does not appear in `keys()` / `to_tuple()` /
    dict-style access of the `ModelOutput` — confirm before relying on those.
    """

    def __init__(self, mtp_hidden_states=None, **kwargs):
        # Parent fields must be passed by keyword: the only positional slot is `mtp_hidden_states`.
        super().__init__(**kwargs)
        self.mtp_hidden_states = mtp_hidden_states
|
|
|
|
|
def _get_unpad_data(attention_mask): |
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
|
max_seqlen_in_batch = seqlens_in_batch.max().item() |
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) |
|
return ( |
|
indices, |
|
cu_seqlens, |
|
max_seqlen_in_batch, |
|
) |
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Deprecated alias of `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask`.

    Emits a deprecation warning, then delegates with the same arguments and returns its result.
    """
    # The message now names this function (it used to name `_prepare_4d_attention_mask`) and
    # closes its backticks. `stacklevel=2` attributes the warning to the caller.
    warnings.warn(
        "Calling `transformers.models.BailingMoeV2.modeling_BailingMoeV2._expand_mask` is deprecated and will be "
        "removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask` instead.",
        stacklevel=2,
    )
    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
|
|
|
|
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """Deprecated alias of `AttentionMaskConverter._make_causal_mask`.

    Emits a deprecation warning, then delegates with the same arguments and returns its result.
    """
    # `AttentionMaskConverter` lives in `transformers.modeling_attn_mask_utils`; the old message
    # pointed at a non-existent path under this model's module.
    warnings.warn(
        "Calling `transformers.models.BailingMoeV2.modeling_BailingMoeV2._make_causal_mask` is deprecated and will be "
        "removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask` "
        "instead.",
        stacklevel=2,
    )
    return AttentionMaskConverter._make_causal_mask(
        input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
    )
|
|
|
|
|
class BailingMoeV2RMSNorm(nn.Module):
    """Root-mean-square layer normalization (the T5LayerNorm formulation).

    No mean subtraction and no bias. The input is normalized in float32 by the reciprocal
    RMS of its last dimension, cast back to the input dtype, and scaled by a learned weight.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * (mean_square + self.variance_epsilon).rsqrt()
        # Cast back before applying the weight, matching the reference implementation.
        return self.weight * normalized.to(original_dtype)
|
|
|
|
|
# Register the custom RMSNorm in transformers' `ALL_LAYERNORM_LAYERS` list, so utilities that
# consult that list recognize it as a layer-norm layer.
ALL_LAYERNORM_LAYERS.append(BailingMoeV2RMSNorm)
|
|
|
|
|
class BailingMoeV2RotaryEmbedding(nn.Module):
    """Produces the rotary position embedding (RoPE) ``cos``/``sin`` tables.

    The inverse-frequency initializer is chosen from ``ROPE_INIT_FUNCTIONS`` by rope type.
    The type is read from ``config.rope_scaling`` (key ``"rope_type"``, falling back to the
    legacy ``"type"`` key), or is ``"default"`` when no scaling is configured.
    ``dynamic_rope_update`` wraps ``forward`` and may refresh ``inv_freq`` for dynamic rope types.
    """

    def __init__(self, config: BailingMoeV2Config, device=None):
        super().__init__()

        scaling = getattr(config, "rope_scaling", None)
        if scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = scaling.get("rope_type", scaling.get("type"))

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        self.config = config
        init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        self.rope_init_fn = init_fn

        inv_freq, self.attention_scaling = init_fn(self.config, device)
        # Non-persistent: recomputed at construction rather than stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # Autocast is disabled so the angles are computed in float32; MPS (and any
        # non-string device type) falls back to a CPU autocast context.
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = torch.matmul(freq.float(), positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
|
|
def rotate_half(x):
    """Split the last dimension in two halves ``(a, b)`` and return ``(-b, a)``."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply (possibly partial) rotary position embedding to queries and keys.

    Only the leading ``cos.shape[-1]`` channels of the last dimension are rotated; the
    remaining channels pass through unchanged. This is how partial rotary embedding works.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, e.g. ``[batch_size, seq_len, rotary_dim]``.
        sin (`torch.Tensor`): Sine table, same shape as ``cos``.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which ``cos``/``sin`` get a singleton dimension so they broadcast
            against ``q``/``k``. Use 1 for ``[batch, heads, seq_len, head_dim]`` inputs and
            2 for ``[batch, seq_len, heads, head_dim]`` inputs.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _embed(tensor):
        # Rotate the first `rotary_dim` channels, then reattach the untouched tail.
        rotated_part, passthrough = tensor[..., :rotary_dim], tensor[..., rotary_dim:]
        half = rotated_part.shape[-1] // 2
        swapped = torch.cat((-rotated_part[..., half:], rotated_part[..., :half]), dim=-1)
        return torch.cat([(rotated_part * cos) + (swapped * sin), passthrough], dim=-1)

    return _embed(q), _embed(k)
|
|
|
|
|
class BailingMoeV2MLP(nn.Module):
    """Gated feed-forward network: ``down_proj(act(gate_proj(x)) * up_proj(x))``, all bias-free."""

    def __init__(self, config: BailingMoeV2Config, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size

        d_model, d_ff = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
|
class BailingMoeV2Gate(nn.Module):
    """Sigmoid router with group-limited top-k expert selection.

    The ``num_experts`` experts are split into ``n_group`` equal groups. For each token, only
    the ``topk_group`` groups with the largest "sum of the two best expert scores" stay
    eligible, and the final ``num_experts_per_tok`` experts are chosen among them.
    ``expert_bias`` is added to the scores only for selection. The returned weights come
    from the unbiased sigmoid scores, renormalized to sum to one when k > 1, and are then
    multiplied by ``routed_scaling_factor``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k, self.num_experts = config.num_experts_per_tok, config.num_experts
        self.n_group, self.topk_group = config.n_group, config.topk_group

        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
        self.routed_scaling_factor = config.routed_scaling_factor

        # Selection-only bias; a buffer, so it is saved with the state dict but is not a Parameter.
        self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Kaiming-uniform initialization of the router weight (the scheme `nn.Linear` uses)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def group_limited_topk(
        self,
        scores: torch.Tensor,
    ):
        """Pick ``top_k`` experts per row of ``scores`` (shape ``(tokens, num_experts)``).

        Returns ``(values, indices)`` of the selected experts, sorted by descending value.
        Experts outside the best ``topk_group`` groups are masked to ``-inf`` first.
        """
        tokens, _unused = scores.size()
        per_group = self.num_experts // self.n_group

        grouped = scores.view(tokens, self.n_group, -1)
        group_strength = grouped.topk(2, dim=-1)[0].sum(dim=-1)
        winning_groups = group_strength.topk(self.topk_group, dim=-1, sorted=False)[1]

        keep = torch.zeros_like(group_strength)
        keep.scatter_(1, winning_groups, 1)
        keep = keep.unsqueeze(-1).expand(tokens, self.n_group, per_group).reshape(tokens, -1)

        masked = scores.masked_fill(keep.bool().logical_not(), float('-inf'))
        values, indices = torch.topk(masked, k=self.top_k, dim=-1)
        return values, indices

    def forward(self, hidden_states):
        """Route tokens.

        Returns ``(topk_idx, topk_weight, logits)``. ``topk_idx``/``topk_weight`` have shape
        ``(tokens, top_k)`` and ``logits`` has shape ``(tokens, num_experts)``, where
        ``tokens`` is the product of all leading dimensions of ``hidden_states``.
        """
        flat = hidden_states.view(-1, hidden_states.shape[-1])
        logits = F.linear(flat.type(torch.float32), self.weight.type(torch.float32))

        probs = torch.sigmoid(logits.float()).type_as(logits)
        _, topk_idx = self.group_limited_topk(probs + self.expert_bias)

        chosen = torch.gather(probs, dim=1, index=topk_idx).type_as(logits)
        if self.top_k > 1:
            chosen = chosen / (chosen.sum(dim=-1, keepdim=True) + 1e-20)
        topk_weight = chosen * self.routed_scaling_factor

        return topk_idx, topk_weight, logits
|
|
|
|
|
class BailingMoeV2SparseMoeBlock(nn.Module):
    """
    Mixture-of-experts feed-forward block with optional shared experts.

    Every token is routed by `BailingMoeV2Gate` to `config.num_experts_per_tok` routed experts, and
    their outputs are combined with the gate weights. When `config.num_shared_experts` is not None, a
    dense shared MLP (intermediate size `moe_intermediate_size * num_shared_experts`) is applied to all
    tokens and added to the routed output.
    """

    def __init__(self, config: BailingMoeV2Config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self._setup_experts()
        self.gate = BailingMoeV2Gate(config)
        if config.num_shared_experts is not None:
            self.shared_experts = BailingMoeV2MLP(
                config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts
            )

    def _setup_experts(self):
        """Create `config.num_experts` routed expert MLPs, each of width `config.moe_intermediate_size`."""
        self.experts = nn.ModuleList(
            [
                BailingMoeV2MLP(config=self.config, intermediate_size=self.config.moe_intermediate_size)
                for _ in range(self.config.num_experts)
            ]
        )

    def forward(self, hidden_states):
        """Run the MoE block on `hidden_states` of shape `(bsz, seq_len, hidden)`.

        Returns:
            `(y, (router_logits, topk_idx))`, where `y` has the input's shape, `router_logits` is
            `(bsz, seq_len, num_experts)` and `topk_idx` is `(bsz, seq_len, num_experts_per_tok)`.
        """
        identity = hidden_states
        bsz, seq_len, h = hidden_states.shape
        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # Duplicate every token once per selected expert so that row i of `hidden_states`
            # lines up with expert id `flat_topk_idx[i]`. Each expert then processes its rows
            # with autograd enabled.
            hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
            y = torch.empty_like(hidden_states)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
            # Weighted sum over each token's k expert outputs.
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.to(hidden_states.dtype).view(bsz, seq_len, h)
        else:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h)
        if self.config.num_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1))

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Gradient-free expert dispatch used outside training.

        Args:
            x: Flattened tokens `(num_tokens, hidden)`.
            topk_ids: Selected expert ids `(num_tokens, k)`.
            topk_weight: Matching routing weights `(num_tokens, k)`.

        Returns:
            `(num_tokens, hidden)` weighted combination of each token's expert outputs.

        NOTE(review): with `num_tokens == 0` the final `view(*topk_ids.shape, -1)` has no unique size
        to infer and presumably raises — confirm whether empty batches can reach this path.
        """
        # Count how many (token, slot) assignments each expert receives.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort the flattened (token, slot) pairs by expert id; `// k` maps a slot back to its token row.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        # Host copy of the counts for slicing (a synchronizing transfer when on GPU).
        tokens_per_expert = tokens_per_expert.cpu().numpy()
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            # Skipping without advancing is safe: end_idx == start_idx when num_tokens == 0.
            if num_tokens == 0:
                continue
            expert = self.experts[i]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out.to(x.device))
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        # Scatter the expert-sorted outputs back to their original (token, slot) order.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        # Weight each slot in the router's dtype, sum over slots, then cast back.
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile key/value heads for grouped-query attention.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep == 1`` the input
    object itself is returned.
    """
    # Unpack first so a non-4D input is rejected even when no repetition is needed.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
|
|
class BailingMoeV2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper (eager implementation).

    Uses one fused `query_key_value` projection and supports grouped-query attention
    (`num_key_value_heads` key/value heads shared across `num_attention_heads` query heads).
    When `config.use_qk_norm` is set, queries and keys get a per-head RMSNorm. Rotary
    position embeddings are applied to the leading channels covered by the `cos`/`sin` tables.
    """

    def __init__(self, config: BailingMoeV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Use `hidden_size // num_heads` when the config has no (or a falsy) explicit `head_dim`.
        self.head_dim = getattr(config, "head_dim", None) or self.hidden_size // self.num_heads
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        # Not read by forward below: the rotary width comes from `cos.shape[-1]` in apply_rotary_pos_emb.
        self.rope_dim = int(self.head_dim * partial_rotary_factor)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # Fused projection laid out as [query heads | key heads | value heads] along the output dim.
        self.query_key_value = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.use_qkv_bias,
        )

        if self.config.use_qk_norm:
            self.query_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, num_heads * head_dim)` to `(bsz, num_heads, seq_len, head_dim)`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager attention.

        Args:
            hidden_states: `(bsz, q_len, hidden_size)`.
            attention_mask: Optional additive mask of shape `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: Unused here; positions enter through `position_embeddings`.
            past_key_value: Optional cache, updated in place with this layer's keys/values.
            output_attentions: Return the softmax probabilities when True.
            use_cache: Unused here; caching is driven by `past_key_value`.
            position_embeddings: Required `(cos, sin)` rotary tables.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`.

        Raises:
            ValueError: on a cache without `layer_idx`, or on unexpected weight/mask/output shapes.
        """
        bsz, q_len, _ = hidden_states.size()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)

        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand the shared key/value heads to match the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        kv_seq_len = key_states.shape[-2]
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.dense(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
|
|
|
|
|
|
class BailingMoeV2FlashAttention2(BailingMoeV2Attention):
    """
    BailingMoeV2 attention backed by FlashAttention-2.

    It reuses the weights of `BailingMoeV2Attention`; only the forward pass differs. The forward calls
    the public `flash_attn` API. When a 2D padding mask is given, it unpads the batch into
    variable-length sequences before the kernel and re-pads the result afterwards. Attention
    weights are never returned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # flash_attn < 2.1 aligns the causal mask to the top-left corner; from 2.1 on it is bottom-right
        # aligned, which is what q_len != kv_len (decoding with a cache) needs. Record which applies.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """FlashAttention-2 forward.

        `attention_mask`, when given, is the 2D padding mask `(bsz, kv_seq_len)` consumed by
        `_upad_input`. Returns `(attn_output, None, past_key_value)`.
        """
        # FlashAttention never materializes the attention probabilities.
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)

        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # Same guard as the eager implementation: the cache is indexed by layer.
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # flash_attn expects (batch, seq_len, num_heads, head_dim). Grouped K/V heads are handled by the
        # kernel itself, so no repeat_kv here.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # The FlashAttention kernels do not accept float32. If upstream layers (e.g. norms kept in
        # float32) upcast the activations, cast back to a half-precision dtype.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.query_key_value.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.dense(attn_output)

        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Call Flash Attention. If a padding mask is given, first unpad the inputs, run the
        variable-length kernel, then re-pad the output.

        Args:
            query_states (`torch.Tensor`):
                Query states `(batch, q_len, num_heads, head_dim)` passed to the Flash Attention API.
            key_states (`torch.Tensor`):
                Key states `(batch, kv_len, num_key_value_heads, head_dim)` passed to the Flash Attention API.
            value_states (`torch.Tensor`):
                Value states, same layout as `key_states`.
            attention_mask (`torch.Tensor`):
                The padding mask: a tensor of size `(batch_size, seq_len)` where 0 marks
                padding tokens and 1 marks real tokens.
            query_length (`int`):
                Number of query tokens along the sequence dimension of `query_states`. It is used
                to choose the unpadding strategy and the causal-mask handling.
            dropout (`float`, *optional*):
                Attention dropout probability.
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # With a top-left aligned mask, a single decoding query would only see the first key, so
            # the causal flag is dropped for q_len == 1 (that query may attend to the whole cache).
            causal = self.is_causal and query_length != 1

        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Flatten padded Q/K/V into the packed layout `flash_attn_varlen_func` expects.

        Returns `(q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            # Prefill without a cache: queries share the keys' padding pattern.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decoding: one query per sequence.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The `-query_length:` slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            # flash-attn >= 2.7 returns an extra trailing value (used seqlens); `*_` keeps both
            # the 4-tuple and 5-tuple forms working.
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
|
|
|
|
|
|
class BailingMoeV2SdpaAttention(BailingMoeV2Attention):
    """
    BailingMoeV2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `BailingMoeV2Attention` because the weights are unchanged. Only the forward pass differs, to adapt to the
    SDPA API. When `output_attentions=True` it falls back to the eager implementation.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """SDPA forward.

        Returns `(attn_output, None, past_key_value)`, or the eager result when `output_attentions` is True.

        Raises:
            ValueError: on a cache without `layer_idx`, or on an attention mask whose shape is not
                `(bsz, 1, q_len, kv_seq_len)`.
        """
        if output_attentions:
            logger.warning_once(
                "BailingMoeV2Model is using BailingMoeV2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            # The eager forward unpacks `position_embeddings`, so it must be forwarded; leaving it out
            # makes the fallback fail with a TypeError on `None`.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        bsz, q_len, _ = hidden_states.size()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)

        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # Same guard as the eager implementation: the cache is indexed by layer.
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            kv_seq_len = key_states.shape[-2]
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # NOTE(review): mirrors the upstream transformers workaround for older torch SDPA CUDA backends
        # mishandling non-contiguous inputs combined with a custom mask.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # Let SDPA build the causal mask itself only when no explicit mask is given; a single query
            # token needs no causal restriction.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.dense(attn_output)

        return attn_output, None, past_key_value
|
|
|
|
|
# Lookup from `config._attn_implementation` to the attention module class.
ATTENTION_CLASSES = dict(
    eager=BailingMoeV2Attention,
    flash_attention_2=BailingMoeV2FlashAttention2,
    sdpa=BailingMoeV2SdpaAttention,
)
|
|
|
|
|
class BailingMoeV2MTPLayer(nn.Module):
    """Multi-token-prediction (MTP) layer.

    Normalizes the incoming token embeddings (`enorm`) and the previous stage's hidden states
    (`hnorm`), fuses their concatenation back to `hidden_size` with `eh_proj`, then applies one
    pre-norm attention block and one pre-norm sparse-MoE block (each with a residual add), and
    finishes with `final_layernorm`.
    """

    def __init__(self, config: BailingMoeV2Config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Attribute names and creation order are deliberately unchanged: they determine the
        # checkpoint (state_dict) keys and the order in which parameters are created.
        self.input_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.enorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.post_attention_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = BailingMoeV2SparseMoeBlock(config)
        self.hnorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.final_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_embeds,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one MTP step.

        Args:
            input_embeds: embeddings fused with `hidden_states`; their last dimension must be
                `hidden_size` (both are RMS-normalized over `hidden_size` and concatenated on the
                last axis before `eh_proj`).
            hidden_states: hidden states from the previous stage, same last dimension.
            attention_mask, position_ids, past_key_value, output_attentions, use_cache,
            position_embeddings: forwarded unchanged to the attention module.
            output_router_logits: whether to append the MoE router logits to the outputs.

        Returns:
            A tuple `(hidden_states,)`, followed by the attention weights if `output_attentions`,
            the attention module's returned cache if `use_cache`, and the router logits (or `None`
            when the MoE block returns no tuple) if `output_router_logits`.
        """
        # Fuse normalized embeddings with normalized previous-stage states.
        fused = self.eh_proj(torch.cat([self.enorm(input_embeds), self.hnorm(hidden_states)], dim=-1))

        # Pre-norm self-attention with residual connection around it.
        attn_out, attn_weights, present_kv = self.attention(
            hidden_states=self.input_layernorm(fused),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
            use_cache=use_cache,
        )
        after_attn = fused + attn_out

        # Pre-norm sparse MoE; the block may return (output, router_logits).
        moe_out = self.mlp(self.post_attention_layernorm(after_attn))
        router_logits = None
        if isinstance(moe_out, tuple):
            moe_out, router_logits = moe_out
        # `.to(...)` keeps the residual add on one device if the MoE output lands elsewhere.
        layer_out = self.final_layernorm(after_attn + moe_out.to(after_attn.device))

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present_kv)
        if output_router_logits:
            extras.append(router_logits)
        return (layer_out, *extras)
|
|
|
|
|
class BailingMoeV2DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: attention + (dense MLP or sparse MoE), each wrapped in
    a residual connection.

    The feed-forward part is a `BailingMoeV2SparseMoeBlock` when `config.num_experts` is set and
    `layer_idx >= config.first_k_dense_replace`; otherwise it is a dense `BailingMoeV2MLP`.
    """

    def __init__(self, config: BailingMoeV2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Attribute names and creation order are kept unchanged (checkpoint keys / parameter order).
        self.attention = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        use_moe = config.num_experts is not None and layer_idx >= config.first_k_dense_replace
        if use_moe:
            self.mlp = BailingMoeV2SparseMoeBlock(config)
        else:
            self.mlp = BailingMoeV2MLP(config=config, intermediate_size=config.intermediate_size)
        self.input_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the layer to `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor`): layer input; its last dimension is `hidden_size`.
            attention_mask (`torch.Tensor`, *optional*): mask prepared by the parent model and
                passed unchanged to the attention module (a padding mask or `None` for flash
                attention, a 4D mask otherwise).
            position_ids (`torch.LongTensor`, *optional*): token positions, forwarded to attention.
            past_key_value (*optional*): key/value cache forwarded to attention.
            output_attentions (`bool`, *optional*): append the attention weights to the outputs.
            output_router_logits (`bool`, *optional*): append the MoE router logits to the outputs
                (`None` for dense layers, whose MLP does not return a tuple).
            use_cache (`bool`, *optional*): append the cache returned by the attention module.
            position_embeddings (*optional*): `(cos, sin)` rotary tables forwarded to attention.

        Returns:
            `(hidden_states,)` followed by the optional extras in the order: attention weights,
            cache, router logits.
        """
        # Attention sub-block (pre-norm + residual).
        attn_out, attn_weights, present_kv = self.attention(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm + residual); MoE blocks return (output, router_logits).
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        router_logits = None
        if isinstance(mlp_out, tuple):
            mlp_out, router_logits = mlp_out
        hidden_states = hidden_states + mlp_out.to(hidden_states.device)

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present_kv)
        if output_router_logits:
            extras.append(router_logits)
        return (hidden_states, *extras)
|
|
|
|
|
BAILINGMOEV2_START_DOCSTRING = r""" |
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Parameters: |
|
config ([`BailingMoeV2Config`]): |
|
Model configuration class with all the parameters of the model. Initializing with a config file does not |
|
load the weights associated with the model, only the configuration. Check out the |
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
|
|
@add_start_docstrings(
    "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.",
    BAILINGMOEV2_START_DOCSTRING,
)
class BailingMoeV2PreTrainedModel(PreTrainedModel):
    # Shared base for BailingMoeV2 models: binds the config class and tells the `PreTrainedModel`
    # machinery how to initialize weights and how the model may be dispatched across devices.
    config_class = BailingMoeV2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices by `device_map` dispatch. Both layer types
    # add a residual to the output of a different submodule (attention, then MLP/MoE), so their
    # submodules have to live on the same device. The MTP layer was previously missing here.
    _no_split_modules = ["BailingMoeV2DecoderLayer", "BailingMoeV2MTPLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize `nn.Linear` and `nn.Embedding` weights.

        Weights are drawn from N(0, `config.initializer_range`); linear biases and the embedding
        row at `padding_idx` (if any) are zeroed. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
|
|
|
|
|
BAILINGMOEV2_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
|
it. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see |
|
`past_key_values`). |
|
|
|
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] |
|
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more |
|
information on the default strategy. |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
|
config.n_positions - 1]`. |
|
|
|
[What are position IDs?](../glossary#position-ids) |
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): |
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention |
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` |
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. |
|
|
|
Two formats are allowed: |
|
- a [`~cache_utils.Cache`] instance; |
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of |
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy |
|
cache format. |
|
|
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed and

`use_cache=True`, a new [`~cache_utils.DynamicCache`] is created and returned.
|
|
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't |
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` |
|
of shape `(batch_size, sequence_length)`. |
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
|
model's internal embedding lookup matrix. |
|
use_cache (`bool`, *optional*): |
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see |
|
`past_key_values`). |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
@add_start_docstrings(
    "The bare BailingMoeV2 Model outputting raw hidden-states without any specific head on top.",
    BAILINGMOEV2_START_DOCSTRING,
)
class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BailingMoeV2DecoderLayer`].
    When `config.num_nextn_predict_layers > 0`, that many [`BailingMoeV2MTPLayer`] multi-token-prediction layers are
    appended after the decoder layers; their outputs are returned separately as `mtp_hidden_states`.

    Args:
        config: BailingMoeV2Config
    """

    def __init__(self, config: BailingMoeV2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.num_nextn_predict_layers = config.num_nextn_predict_layers

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Decoder layers first, then the MTP layers; all share one ModuleList so their layer_idx
        # values (and cache slots) are contiguous.
        self.layers = []
        for layer_idx in range(config.num_hidden_layers + config.num_nextn_predict_layers):
            layer_cls = BailingMoeV2DecoderLayer if layer_idx < config.num_hidden_layers else BailingMoeV2MTPLayer
            self.layers.append(layer_cls(config, layer_idx))

        self.layers = nn.ModuleList(self.layers)

        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BailingMoeV2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value

    @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, MoeV2ModelOutputWithPast]:
        r"""
        Runs the decoder layers, the final norm and then (if configured) the MTP layers.

        Each MTP layer receives the embeddings of `input_ids` rolled once more with `roll_tensor(..., shifts=-1)`
        together with the hidden states of the previous stage (the normalized decoder output for the first MTP
        layer). MTP outputs are returned only on the `ModelOutput` path as `mtp_hidden_states`; the tuple return
        keeps its historical layout `(last_hidden_state, cache, hidden_states, attentions, router_logits)` with
        `None` entries dropped.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if MTP layers are
                configured but only `inputs_embeds` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The MTP layers re-embed rolled `input_ids`; without them the MTP loop would fail deep inside
        # `roll_tensor`. Fail early with an actionable message instead.
        if self.num_nextn_predict_layers > 0 and input_ids is None:
            raise ValueError(
                "BailingMoeV2 multi-token-prediction layers build their inputs from `input_ids`; passing only "
                "`inputs_embeds` is not supported when `config.num_nextn_predict_layers > 0`."
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

        if position_ids is None:
            # Positions continue after the tokens already stored in the cache.
            position_ids = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
            position_ids = position_ids.unsqueeze(0)

        if self._use_flash_attention_2:
            # Flash attention takes the 2D padding mask directly; drop it when nothing is padded.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions:
            # SDPA path: may return None so SDPA can use its own causal kernel.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_seen_tokens,
            )
        else:
            # Eager path (or SDPA with output_attentions): explicit 4D causal mask.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens
            )

        hidden_states = inputs_embeds

        # Rotary tables computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None
        layers = self.layers[: -self.num_nextn_predict_layers] if self.num_nextn_predict_layers > 0 else self.layers
        mtp_layers = self.layers[-self.num_nextn_predict_layers :] if self.num_nextn_predict_layers > 0 else None

        for decoder_layer in layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits and layer_outputs[-1] is not None:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)
        main_hidden_states = hidden_states

        if output_hidden_states:
            all_hidden_states += (main_hidden_states,)

        mtp_hidden_states = None

        if mtp_layers:
            for decoder_layer in mtp_layers:
                # Each MTP depth looks one more token ahead: roll the ids again and re-embed them.
                input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
                inputs_embeds = self.word_embeddings(input_ids)

                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        inputs_embeds,
                        hidden_states,
                        attention_mask,
                        position_ids,
                        past_key_values,
                        output_attentions,
                        output_router_logits,
                        use_cache,
                        position_embeddings,
                    )
                else:
                    layer_outputs = decoder_layer(
                        inputs_embeds,
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        output_router_logits=output_router_logits,
                        use_cache=use_cache,
                        position_embeddings=position_embeddings,
                    )
                if mtp_hidden_states is None:
                    mtp_hidden_states = []
                hidden_states = layer_outputs[0]
                mtp_hidden_states.append(hidden_states)

                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

                if output_router_logits and layer_outputs[-1] is not None:
                    all_router_logits += (layer_outputs[-1],)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache
        if not return_dict:
            return tuple(
                v
                for v in [main_hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeV2ModelOutputWithPast(
            last_hidden_state=main_hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            mtp_hidden_states=mtp_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
|
|
|
|
|
class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BailingMoeV2Config):
        super().__init__(config)
        self.model = BailingMoeV2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.num_nextn_predict_layers = config.num_nextn_predict_layers
        self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor

        self.post_init()

    def get_input_embeddings(self):
        return self.model.word_embeddings

    def set_input_embeddings(self, value):
        self.model.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoEV2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, MoEV2CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

                When `config.num_nextn_predict_layers > 0`, every MTP depth `k` (1-based) is additionally trained to
                predict the token `k` positions further ahead (labels rolled left `k` times, padded with -100), and
                `config.mtp_loss_scaling_factor * mtp_loss_k` is added to the returned `loss`. The per-depth losses
                and logits are returned as `mtp_loss` / `mtp_logits` on the `ModelOutput` path only.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> model = BailingMoeV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the backbone for a ModelOutput: the MTP head needs `mtp_hidden_states`, which the
        # backbone's tuple return does not carry. The legacy tuple is rebuilt below when requested.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=True,
            **kwargs,
        )

        loss = None
        all_mtp_loss = None
        aux_loss = None
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        if labels is not None:
            loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)

        all_mtp_logits = None
        if self.num_nextn_predict_layers > 0:
            all_mtp_logits = []
            shift_labels_mtp = None
            if labels is not None:
                all_mtp_loss = []
                shift_labels_mtp = labels.clone()
            for i in range(self.num_nextn_predict_layers):
                # Use a distinct name: the previous code overwrote the per-layer list with a tensor,
                # so for i >= 1 it indexed the batch dimension instead of the next MTP layer.
                mtp_hidden = outputs.mtp_hidden_states[i]
                mtp_logits = self.lm_head(mtp_hidden).float()
                all_mtp_logits.append(mtp_logits)
                if labels is not None:
                    shift_labels_mtp, _ = roll_tensor(shift_labels_mtp, shifts=-1, dims=-1, fill_value=-100)
                    # Keep the (batch, seq) layout, like the main loss above: the causal-LM loss pads and
                    # shifts labels along the last axis, so flattening first would let each row's last
                    # position be scored against the next row's token.
                    mtp_loss = self.loss_function(
                        mtp_logits,
                        shift_labels_mtp.to(mtp_logits.device),
                        self.config.vocab_size,
                        **kwargs,
                    )
                    # `labels is not None` guarantees the main loss was computed above.
                    loss = loss + self.mtp_loss_scaling_factor * mtp_loss
                    all_mtp_loss.append(mtp_loss)

        if not return_dict:
            # Same layout the backbone's tuple return produced: non-None entries of
            # (cache, hidden_states, attentions, router_logits) after the logits.
            backbone_extras = tuple(
                v
                for v in (outputs.past_key_values, outputs.hidden_states, outputs.attentions, outputs.router_logits)
                if v is not None
            )
            output = (logits,) + backbone_extras
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return MoEV2CausalLMOutputWithPast(
            loss=loss,
            mtp_loss=all_mtp_loss,
            aux_loss=aux_loss,
            logits=logits,
            mtp_logits=all_mtp_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
|
|
|