diff --git "a/modeling_ernie_45t_vl.py" "b/modeling_ernie_45t_vl.py" new file mode 100644--- /dev/null +++ "b/modeling_ernie_45t_vl.py" @@ -0,0 +1,4210 @@ +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ernie VL model""" +import re +import math +import itertools +from dataclasses import dataclass +from collections import defaultdict +from copy import deepcopy +from functools import partial +from typing import List, Optional, Tuple, Union + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import ModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from .configuration_ernie_45t_vl import ( + DFNRopeVisionTransformerConfig, + Ernie4_5_MoEConfig, + Ernie4_5_VLMoEConfig, +) + +logger = logging.get_logger(__name__) + + +__all__ = [ + "Ernie4_5_VLMoeForConditionalGeneration", + "DFNRopeVisionTransformerPreTrainedModel", + "VariableResolutionResamplerModel", +] + + +class TokenType: + """token type definition""" + + text = 0 + image = 1 + video = 2 + + +class UniqueNameGuard: + """name guard""" + + def __init__(self, prefix=""): + self.prefix = prefix + self.counter = {} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def get_unique_name(self, name): + """get unique name""" + if name not in self.counter: + self.counter[name] = 0 + else: + self.counter[name] += 1 + return f"{self.prefix}{name}_{self.counter[name]}" + + +class RopeEmbedding(nn.Module): + """ + Rotary Position Embedding (RoPE) implementation for transformer models. + + RoPE encodes absolute positional information with rotation matrices and + naturally incorporates relative position information in self-attention. + + Args: + head_dim (int): Dimension size of each attention head + compression_ratio (float, optional): Sequence length compression ratio. Defaults to 1.0. + base (int, optional): Base value for frequency calculation. Defaults to 10000. + + Attributes: + head_dim (int): Dimension size of each attention head + compression_ratio (float): Sequence length compression factor + base (int): Base value for frequency calculation + """ + + def __init__(self, head_dim, compression_ratio=1.0, base=10000, freq_allocation=0): + """ + Initialize RoPE embedding layer. + + Args: + head_dim: Dimension of each attention head + compression_ratio: Scaling factor for position indices + base: Base value for frequency calculation + """ + super().__init__() + self.head_dim = head_dim + self.compression_ratio = compression_ratio + self.base = base + + # num of freq allocated to time + self.freq_allocation = freq_allocation + + def forward(self, seq_length, position_ids=None): + """ + Compute rotary position embeddings for given sequence length. 
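+ + Note: the inverse frequencies follow 1 / base**(2*i / head_dim) for i in [0, head_dim // 2); the returned embedding concatenates sin(position * inv_freq) and cos(position * inv_freq) along the last axis.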
+ + Args: + seq_length (int): Maximum sequence length + position_ids (Tensor, optional): Custom position indices. Defaults to None. + + Returns: + Tensor: Rotary position embeddings of shape [1, 1, seq_length, head_dim] + """ + indices = torch.arange(0, self.head_dim, 2, dtype=torch.float32) + indices = 1 / self.base ** (indices / self.head_dim) + if position_ids is None: + position_ids = torch.arange( + 0, seq_length, 1, dtype=torch.float32 + ).unsqueeze(1) + position_ids = position_ids / self.compression_ratio + sinusoid_inp = position_ids * indices.unsqueeze(0) + else: + position_ids = position_ids / self.compression_ratio + seq_length = position_ids.shape[-1] + sinusoid_inp = position_ids.unsqueeze(-1).to( + torch.float32 + ) * indices.unsqueeze(0) + pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1) + pos_emb = pos_emb.view(-1, 1, seq_length, self.head_dim) + pos_emb = pos_emb.detach() + return pos_emb + + def apply_rotary(self, rp, q, k): + """ + Apply rotary position embeddings to queries and keys. + + Args: + rp (Tensor): Rotary position embeddings + q (Tensor): Query tensor [batch, heads, seq_len, dim] + k (Tensor): Key tensor [batch, heads, seq_len, dim] + + Returns: + Tuple[Tensor, Tensor]: Rotated queries and keys + """ + sin, cos = torch.chunk(rp, 2, dim=-1) + # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + sin_pos = torch.stack([sin, sin], dim=-1).reshape(rp.shape) + # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + cos_pos = torch.stack([cos, cos], dim=-1).reshape(rp.shape) + # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] + rotate_half_q = torch.stack( + [-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1 + ).reshape(q.shape) + query = (q.to(torch.float32) * cos_pos) + ( + rotate_half_q.to(torch.float32) * sin_pos + ) + # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] + rotate_half_k = torch.stack( + [-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1 + ).reshape(k.shape) + key = (k.to(torch.float32) * cos_pos) + ( + rotate_half_k.to(torch.float32) * sin_pos + ) + return query, key + + def apply_rotary_3d(self, rp, q, k, position_ids): + """ + rope 3d rotary + + args: + rp: [1, max_seqlen, 1, head_dim] + q: [bsz, seqlen, head, head_dim] + k: [bsz, seqlen, head, head_dim] + position_ids: [bsz, seqlen, 3] + """ + current_device = q.device + sin, cos = torch.chunk(rp, 2, axis=-1) + assert position_ids.shape[:1] == q.shape[:1] + batch_indices = torch.arange(end=position_ids.shape[0]) + batch_indices = batch_indices[..., None] + sin = sin.tile(position_ids.shape[0], 1, 1, 1).to(device=position_ids.device) + cos = cos.tile(position_ids.shape[0], 1, 1, 1).to(device=position_ids.device) + + assert self.freq_allocation != 0 + sin_t = sin[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] + sin_h = sin[ + batch_indices, + position_ids[..., 1], + :, + : self.head_dim // 2 - self.freq_allocation : 2, + ] + sin_w = sin[ + batch_indices, + position_ids[..., 2], + :, + 1 : self.head_dim // 2 - self.freq_allocation : 2, + ] + sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape( + sin_h.shape[:-1] + (sin_h.shape[-1] * 2,) + ) + sin_thw = torch.cat([sin_hw, sin_t], dim=-1) + + cos_t = cos[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] + cos_h = cos[ + batch_indices, + position_ids[..., 1], + :, + : self.head_dim // 2 - self.freq_allocation : 2, + ] + cos_w = cos[ + batch_indices, + position_ids[..., 2], + :, + 1 : self.head_dim // 2 - self.freq_allocation : 2, 
+ ] + cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape( + cos_h.shape[:-1] + (cos_h.shape[-1] * 2,) + ) + cos_thw = torch.cat([cos_hw, cos_t], dim=-1) + + # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + sin_pos = ( + torch.stack([sin_thw, sin_thw], dim=-1) + .reshape(sin_thw.shape[:3] + (sin_thw.shape[-1] * 2,)) + .to(current_device) + ) + # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + cos_pos = ( + torch.stack([cos_thw, cos_thw], dim=-1) + .reshape(cos_thw.shape[:3] + (cos_thw.shape[-1] * 2,)) + .to(current_device) + ) + + # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] + rotate_half_q = torch.stack( + [-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1 + ).reshape(q.shape) + query = (q.to(torch.float32) * cos_pos) + ( + rotate_half_q.to(torch.float32) * sin_pos + ) + # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] + rotate_half_k = torch.stack( + [-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1 + ).reshape(k.shape) + key = (k.to(torch.float32) * cos_pos) + ( + rotate_half_k.to(torch.float32) * sin_pos + ) + return query, key + + +class Ernie4_5_MLP(nn.Module): + """ + Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model. + """ + + def __init__(self, config, layer_idx=0): + """ + Initialize the MLP module with configuration options. + + Args: + config (Ernie4_5_Config): Model configurations. + layer_idx (int): Index of current layer (default: 0) + """ + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.use_bias + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.use_bias + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=config.use_bias + ) + + def forward(self, x): + """ + Forward pass through the MLP module. + + Args: + x (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + + Returns: + Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] + """ + current_device = self.gate_proj.weight.data.device + x = x.to(current_device) + down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Ernie4_5_Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, layer_idx=0): + """Initialize the attention layer. + + Args: + config (Ernie4_5_Config): Model configuration. + layer_idx (int, optional): Index in transformer stack. Defaults to 0. + """ + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.is_gqa = ( + self.num_key_value_heads is not None + and self.num_key_value_heads != self.num_heads + ) + + self.freq_allocation = getattr(config, "freq_allocation", 0) + assert ( + self.freq_allocation is not None + ), "freq_allocation must be provided if rope_3d is on." 
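+ + # With 3D RoPE, the last freq_allocation of the head_dim // 2 rotary frequencies are reserved for the temporal axis and the remaining ones are interleaved between height and width (see RopeEmbedding.apply_rotary_3d).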
+ + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + if self.is_gqa: + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = ( + self.num_key_value_heads // config.tensor_parallel_degree + ) + q_hidden_size = self.head_dim * self.num_heads + if self.is_gqa: + logger.info( + f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}" + ) + assert ( + self.num_heads % self.num_key_value_heads == 0 + ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}" + kv_hidden_size = self.head_dim * self.num_key_value_heads + else: + kv_hidden_size = self.head_dim * self.num_heads + + self.q_proj = nn.Linear(self.hidden_size, q_hidden_size, bias=config.use_bias) + self.k_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias) + self.v_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias) + + self.o_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias=config.use_bias, + ) + + self.rotary_emb = RopeEmbedding( + self.head_dim, + compression_ratio=config.compression_ratio, + base=config.rope_theta, + freq_allocation=self.freq_allocation, + ) + self.config = config + self.attn_func = self.core_attn + + def forward( + self, + hidden_states, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + attn_mask_start_row_indices: Optional[torch.Tensor] = None, + position_ids: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + token_type_ids: Optional[Tuple[torch.Tensor]] = None, # MLLM + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Compute attention outputs. 
+ + Args: + hidden_states (torch.Tensor): Input tensor [bsz, seq_len, hidden_size] + past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached key/value states + attention_mask (Optional[torch.Tensor]): Attention mask tensor + attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices + position_ids (Optional[torch.Tensor]): Position indices for RoPE + output_attentions (bool): Return attention weights if True + use_cache (bool): Cache key/value states if True + + Returns: + Tuple containing: + - attention_output: [bsz, seq_len, hidden_size] + - attention_weights: Optional attention probabilities + - updated_key_value_cache: Optional updated cache + """ + if token_type_ids is not None: + token_type_ids = token_type_ids[:, :-1] + + bsz, q_len, _ = hidden_states.shape + query_states = self.q_proj(hidden_states).reshape( + [bsz, q_len, -1, self.head_dim] + ) + key_states = self.k_proj(hidden_states).reshape([bsz, q_len, -1, self.head_dim]) + value_states = self.v_proj(hidden_states).reshape( + [bsz, q_len, -1, self.head_dim] + ) + + attn_output, attn_weights, past_key_value = self.rope_attn( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_start_row_indices=attn_mask_start_row_indices, + ) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def repeat_kv(self, hidden_states, n_rep): + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def core_attn( + self, + q, + k, + v, + attention_mask=None, + attn_mask_start_row_indices=None, + seq_length=None, + ): + """Standard self-attention implementation. 
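+ + Queries are divided by scale_qk_coeff * sqrt(head_dim) before the QK^T matmul; when no attention_mask is supplied, a causal (upper-triangular) mask is applied instead.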
+ + Args: + q (torch.Tensor): Query tensor + k (torch.Tensor): Key tensor + v (torch.Tensor): Value tensor + attention_mask (Optional[torch.Tensor]): Attention mask + attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices + seq_length (Optional[int]): Sequence length + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Attention output and weights + """ + origin_dtype = q.dtype + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + scale_qk_coeff = getattr(self.config, "scale_qk_coeff", 1.0) * ( + self.head_dim**0.5 + ) + + q = q / scale_qk_coeff + + # Handle GQA case - repeat k and v heads to match q heads + if self.is_gqa: + # [batch, num_key_value_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim] + repeat_factor = self.num_heads // self.num_key_value_heads + k = self.repeat_kv(k, repeat_factor) + v = self.repeat_kv(v, repeat_factor) + + product = torch.matmul(q, k.transpose(-2, -1)) + + product = product.to(torch.float32) + if getattr(self.config, "scale_qk_coeff", 1.0) != 1.0: + product = product * getattr(self.config, "scale_qk_coeff", 1.0) + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(torch.float32) + product = product + attention_mask + weights = F.softmax(product, dim=-1) + else: + seq_len = product.size(-1) + mask = torch.triu( + torch.ones((seq_len, seq_len), dtype=torch.bool, device=product.device), + diagonal=1, + ) + product = product.masked_fill(mask, float("-inf")) + weights = F.softmax(product, dim=-1) + + weights = weights.to(origin_dtype) + + if getattr(self.config, "attention_probs_dropout_prob", 0.0) > 0: + weights = F.dropout( + weights, + self.config.attention_probs_dropout_prob, + training=self.training, + ) + + out = torch.matmul(weights, v) + + # combine heads + out = out.permute(0, 2, 1, 3) + out = out.contiguous().view(out.size(0), out.size(1), -1) + + return out, weights + + def rope_attn( + self, + query_states, + key_states, + value_states, + attention_mask, + position_ids, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_start_row_indices=None, + ): + """Attention computation with rotary embeddings. 
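+ + Note: position_ids are expected in 3D form [bsz, seqlen, 3] (time, height, width); rotary embeddings are built up to the maximum position id and applied via apply_rotary_3d before cached key/value states are concatenated.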
+ + Args: + mix_layer (Optional[torch.Tensor]): Combined QKV projection + query_states (torch.Tensor): Query states + key_states (torch.Tensor): Key states + value_states (torch.Tensor): Value states + attention_mask (Optional[torch.Tensor]): Attention mask + position_ids (Optional[torch.Tensor]): Position indices + output_attentions (bool): Return attention weights + past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached states + use_cache (bool): Cache new states + attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices + + Returns: + Tuple containing: + - attention_output: Result tensor + - attention_weights: Optional weights + - updated_key_value_cache: Optional cache + """ + + query_states_dtype = query_states.dtype + + assert position_ids is not None, "rope3d requires pos-id" + kv_seq_len = position_ids.max() + 1 + offset = 0 + if past_key_value is not None: + offset = position_ids.max() + kv_seq_len = position_ids.max() + 1 + position_ids = position_ids[:, -1:, :] + + cos_sin = self.rotary_emb(kv_seq_len).permute([0, 2, 1, 3]) + if offset > 0 and position_ids is None: + cos_sin = cos_sin[:, offset:] + query_states, key_states = self.rotary_emb.apply_rotary_3d( + cos_sin, query_states, key_states, position_ids + ) + + query_states = query_states.to(query_states_dtype) + key_states = key_states.to(query_states_dtype) + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + + # shape: [2, b, s, kvh, d] + past_key_value = [key_states, value_states] if use_cache else None + seq_length = query_states.shape[1] + attn_output, attn_weights = self.attn_func( + query_states, + key_states, + value_states, + attention_mask, + attn_mask_start_row_indices, + seq_length, + ) + + return attn_output, attn_weights, past_key_value + + +class FusedDropoutImpl(nn.Module): + """ + Fused dropout implementation with residual connection support. + + This layer combines dropout and residual addition in a single operation for better performance, + particularly on GPU devices. The dropout is conditionally applied based on the probability. + + Args: + prob (float): Dropout probability (between 0 and 1) + mode (str): Dropout mode, either 'upscale_in_train' or 'downscale_in_infer' + + Attributes: + prob (float): Stores the dropout probability + mode (str): Stores the dropout mode + dropout (nn.Dropout): The actual dropout layer instance + """ + + def __init__(self, prob, mode): + """ + Initialize the fused dropout layer. + + Args: + prob (float): Dropout probability (0 means no dropout) + mode (str): Dropout mode ('upscale_in_train' or 'downscale_in_infer') + """ + super().__init__() + self.prob = prob + self.dropout = nn.Dropout(p=prob) + + def forward(self, x, y): + """ + Forward pass of the fused dropout layer. + + Args: + x (Tensor): Input tensor to potentially apply dropout on + y (Tensor): Residual tensor to add to the (possibly dropped out) x + + Returns: + Tensor: Result of x (with optional dropout) + y + """ + if self.prob > 0: + x = self.dropout(x) + output = x + y + + return output + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization (RMSNorm) implementation. + + RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs, + omitting the mean-centering operation. This provides computational efficiency while maintaining + good performance. 
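+ + Given an input x, the layer computes x * rsqrt(mean(x**2, dim=-1) + eps) * weight, with the variance accumulated in float32.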
+ + """ + + def __init__(self, config): + """ + Initialize RMSNorm layer. + + Args: + config (Ernie4_5_Config): Model configuration. + """ + super().__init__() + self.hidden_size = config.hidden_size + self.weight = nn.Parameter( + torch.ones(self.hidden_size, dtype=torch.get_default_dtype()) + ) + self.variance_epsilon = config.rms_norm_eps + + def forward(self, hidden_states): + """ + Apply RMS normalization to input hidden states. + + Args: + hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + + Returns: + Tensor: Normalized output tensor of same shape as input + + Note: + - computes RMSNorm manually: + 1. Compute variance of features + 2. Apply reciprocal square root normalization + 3. Scale by learned weight parameter + - Maintains original dtype for numerical stability during computation + """ + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = torch.rsqrt(variance + self.variance_epsilon) * hidden_states + return hidden_states.to(self.weight.dtype) * self.weight + + +class Ernie4_5_MoeMLP(Ernie4_5_MLP): + """Mixture of Experts (MoE) variant of ERNIE's MLP layer.""" + + def __init__(self, config, layer_idx=0): + """Initialize the MoE MLP layer. + + Args: + config (Ernie4_5_MoEConfig): Configuration for MoE architecture. + layer_idx (int): Index of current layer in transformer stack + """ + + if getattr(config, "disable_ffn_model_parallel", False): + config = deepcopy(config) + config.tensor_parallel_degree = 1 + + super().__init__(config, layer_idx=layer_idx) + self.moe_dropout_prob = config.moe_dropout_prob + + def forward(self, x): + """Forward pass through MoE MLP layer. + + Args: + x (paddle.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + or [seq_len, hidden_size] + + Returns: + paddle.Tensor: Output tensor with same shape as input + """ + current_device = self.gate_proj.weight.data.device + x = x.to(current_device) + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + if self.moe_dropout_prob > 0: + x = F.dropout(input=x, p=self.moe_dropout_prob) + ret = self.down_proj(x) + return ret + + +def masked_fill(x, mask, value): + """ + Fills elements of the input tensor with a given value where mask is True. + """ + return torch.where(mask, torch.full_like(x, value), x) + + +def _squared_l2_norm(x: torch.Tensor) -> torch.Tensor: + """Computes 0.5 * sum(x^2)""" + return 0.5 * torch.sum(x * x) + + +@torch.no_grad() +def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): + """ + Computes optimal transport matrix and Sinkhorn distance using Sinkhorn-Knopp algorithm. + """ + n, _ = M.shape + P = F.softmax(-M / lam, dim=1) # Applying softmax over columns + u = torch.zeros(n, dtype=torch.float32, device=M.device) + + for _ in range(max_iters): + P_sum_1 = P.sum(1) + if (u - P_sum_1).abs().max() < epsilon: + break + u = P_sum_1 + P *= (r / (u + 1e-8)).unsqueeze(1) + P *= (c / (P.sum(0) + 1e-8)).unsqueeze(0) + + P = torch.where(~P.isnan(), P, torch.zeros_like(P)) + return P, _ + + +class Top2Gate(nn.Module): + """ + Gate module implementing Top2Gating as described in Gshard paper. + """ + + def __init__(self, config, layer_idx: int, group=None, gate_weight=None) -> None: + """ + Initialize the MoE (Mixture of Experts) layer. 
+ + Args: + config: Model configuration containing MoE parameters + layer_idx: Index of this layer in the model + group: Distributed communication group + gate_weight: Optional pre-existing gate weight tensor + """ + super().__init__() + self.config = config + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = ( + sum(config.moe_num_experts) + if config.multimodel_experts + else config.moe_num_experts + ) + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + + self.sinkhorn_2gate = config.sinkhorn_2gate + self.sinkhorn_temp = config.sinkhorn_temp + self.use_correction_bias = config.moe_use_aux_free # true + self.use_token_type_bias = config.get("moe_use_token_type_bias", False) + + self.act = partial(F.softmax, dim=-1) # [S,E] + + self.no_jitter = True + self.expert_drop = False + self.eye_matrix = None + self.eye_matrix_size = None + self.norm_gate_logits = config.moe_norm_gate_logits # true + self.one = torch.ones([], dtype=torch.float32) + + self.moe_aux_loss_lambda = torch.tensor(config.moe_aux_loss_lambda).to( + dtype=torch.float32 + ) + self.moe_z_loss_lambda = torch.tensor(config.moe_z_loss_lambda).to( + dtype=torch.float32 + ) + self.moe_orthogonal_loss_lambda = torch.tensor( + config.moe_orthogonal_loss_lambda + ).to(dtype=torch.float32) + + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + if self.moe_z_loss_lambda.ndim == 0: + self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) + if self.moe_orthogonal_loss_lambda.ndim == 0: + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze( + 0 + ) + + self.experts_type_ids = None + + self.eps = torch.tensor([1e-12]).to(dtype=torch.float32) + if config.multimodel_experts: + if config.get("moe_use_hard_gate", False): + self.num_experts_list = [] + self.experts_type_mask = [] + # hard-gate + group_experts: gate_logits must be computed separately for each expert group + experts_ids = torch.zeros( + [sum(self.num_experts)], dtype=torch.int64 + ).reshape((1, -1)) + offset = 0 + for i, expert_num in enumerate(self.num_experts): + experts_ids[:, offset : offset + expert_num] = i + offset += expert_num + self.experts_type_ids = experts_ids.reshape([-1]) + logger.info( + f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" + ) + for i, expert_num in enumerate(self.num_experts): + self.experts_type_mask.append( + self.experts_type_ids == i, + ) + self.num_experts_list.append(expert_num) + else: + # without group_experts, hard-gate behavior relies on token_type_bias. + assert ( + not config.moe_group_experts + ), "group_experts must use hard_gate when multimodel_experts is True" + else: + self.num_experts_list = [self.num_experts] + + if gate_weight is not None: + self.weight = gate_weight + + assert ( + not self.config.moe_use_token_type_bias + ), "gate_weights is from outside, token_type_bias can't be used" + logger.info("moe use gate_weight from outside") + # use fp32 precision in amp + self._cast_to_low_precision = False + self._cast_to_low_precison = False + else: + self._create_gate_parameter() + logger.info( + f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " + f"use_token_type_bias:{self.use_token_type_bias} " + f"gate_act:{config.moe_gate_act} " + f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" + ) + + def _create_gate_parameter(self): + """ + Create gate weight parameter.
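+ + With multimodel_experts, one float32 gate weight of shape [model_dim, num_experts_i] is registered per expert group (named "weight", "weight_1", ...); otherwise a single "weight" of shape [model_dim, num_experts] is created. The per-group loss lambdas are expanded accordingly.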
+ """ + if self.config.multimodel_experts: + # support setting lambda for each expert group + self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand( + len(self.num_experts) + ) + + for i, num_experts in enumerate(self.num_experts): + if i == 1: + with UniqueNameGuard(f"mm_gate_{self.layer_idx}_"): + p = nn.Parameter( + torch.empty( + self.model_dim, + num_experts, + dtype=torch.float32, + device="cpu", + ) + ) + nn.init.xavier_uniform_(p) # Common initialization + else: + p = nn.Parameter( + torch.empty( + self.model_dim, + num_experts, + dtype=torch.float32, + device="cpu", + ) + ) + nn.init.xavier_uniform_(p) # Common initialization + self.register_parameter( + "weight" if i == 0 else f"weight_{i}", + p, + ) + else: + self.weight = nn.Parameter( + torch.empty(self.model_dim, self.num_experts, dtype=torch.float32) + ) + nn.init.xavier_uniform_(self.weight) # Common initialization + # use fp32 pecison in amp + self._cast_to_low_precision = False + self._cast_to_low_precison = False + + def get_gate_weight(self, transform_weight, is_multimodel=True): + """ + 在`multimodel_experts` 的情况下,将多个 weights merge 成一个整体 + transform_weight: bool, 按照 local-expert id 将 多模态 weight 交叠 + """ + if not is_multimodel or not self.config.multimodel_experts: + return self.weight + else: + return torch.cat( + [ + getattr(self, "weight" if i == 0 else f"weight_{i}") + for i in range(len(self.num_experts)) + ], + -1, + ) + + def forward( + self, + input: torch.Tensor, + token_type_ids: torch.Tensor = None, + transform_weight: bool = True, + correction_bias: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass through the gate. + + Args: + input: Input tensor of shape [Seq, Dim] + token_type_ids: Token type IDs tensor of shape [Seq] + transform_weight: Whether to transform weights for multimodal experts + correction_bias: Bias tensor for correction + + Returns: + tuple: (capacity, dispatch_mask, combine_weights, scatter_index, router_loss, logits) + """ + orig_dtype = input.dtype + current_device = input.device + weight = self.get_gate_weight(transform_weight) + + logits = F.linear( + input.to(dtype=torch.float32, device=current_device), + weight.T.to(dtype=torch.float32, device=current_device), + ) + + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + l_aux, + l_zloss, + ) = self.top2_gating( + logits, + correction_bias=( + correction_bias.to(device=current_device) + if correction_bias is not None + else None + ), + ) + + combine_weights = combine_weights.to(orig_dtype) + return capacity, dispatch_mask, combine_weights, scatter_index, None, logits + + def get_capacity(self, num_tokens, cap_factor=None, is_multimodel=True): + """ + Calculate capacity based on number of tokens. 
+ + Args: + num_tokens: Number of input tokens + cap_factor: Optional capacity factor override + + Returns: + int: Calculated capacity + """ + if is_multimodel and self.config.multimodel_experts: + num_experts = sum(self.num_experts_list) + elif isinstance(self.num_experts, (list, tuple)): + num_experts = self.num_experts[0] + else: + num_experts = self.num_experts + if cap_factor is not None: + cap = cap_factor + else: + if self.training: + cap = self.cap[0] + elif num_tokens < num_experts: # seqlen < num_expert + cap = self.cap[2] + else: + cap = self.cap[1] + # capacity = 2S/E + capacity = int(cap * num_tokens // num_experts) + assert ( + capacity > 0 + ), f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" + return capacity + + def top2_gating(self, logits, cap=None, correction_bias=None): + """ + Implement Top2 gating mechanism. + + Args: + logits: Input logits tensor + cap: Optional capacity override + correction_bias: Bias tensor for correction + + Returns: + tuple: (capacity, dispatch_masks, combine_weights, scatter_indexes, loss_aux, loss_z) + + Note: + capacity: The maximum number that each token can be dispatched. + dispatch_masks: Masks used for dispatching. The first element is the mask for the first + type of tokens; the second element is the mask for the second type of tokens. + combine_weights: Weights used for combining. The first element is the weight for the first + type of tokens; the second element is the weight for the second type of tokens. + scatter_indexes: Indexes used for scattering. The first element is the index for the first + type of tokens; the second element is the index for the second type of tokens. + loss_aux: Auxiliary loss. + loss_z: Z loss. + """ + gates = self.act(logits) + + # gates has shape of SE + assert logits.ndim == 2, logits.shape + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + # capacity = 2S/E + capacity = self.get_capacity(logits.shape[0], cap) + current_device = logits.device + + # Create a mask for 1st's expert per token + score_for_argmax = ( + gates + correction_bias.unsqueeze(0) + if correction_bias is not None + else gates + ) + indices1_s = torch.argmax(score_for_argmax, dim=1) + mask1 = F.one_hot(indices1_s, num_classes=num_experts).to( + dtype=torch.int64, device=current_device + ) # [0,1] + + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + if self.training and not self.no_jitter: + gumbels = ( + -torch.empty_like( + logits, + device=current_device, + ) + .exponential_() + .log() + ) # ~Gumbel(0,1) + logits_w_noise = logits + gumbels + else: + logits_w_noise = logits + + logits_except1 = masked_fill( + logits_w_noise, + mask1.to(dtype=torch.bool, device=current_device), + float("-inf"), + ) + score_for_argmax = ( + self.act(logits_except1) + correction_bias.unsqueeze(0) + if correction_bias is not None + else logits_except1 + ) + indices2_s_original = torch.argmax(score_for_argmax, dim=1) + + if self.training and self.sinkhorn_2gate: + r = ( + torch.ones(num_tokens, dtype=torch.float32, device=current_device) + / num_tokens + ) + c_mask_sum = mask1.to(dtype=torch.float32, device=current_device).sum(0) + c = capacity - c_mask_sum + c = torch.maximum(c, torch.zeros_like(c, device=current_device)) + c_sum = c.sum() + if c_sum > 0: + c = c / c_sum + else: # Avoid division by zero if all experts are full from top-1 + c = torch.ones_like(c, device=current_device) / num_experts + + pi, _ = compute_optimal_transport( + 
-logits_except1.to(dtype=torch.float32, device=current_device).detach(), + r, + c, + lam=self.sinkhorn_temp, + ) + pi = masked_fill( + pi, mask1.to(dtype=torch.bool, device=current_device), float("-inf") + ) + indices2_s = torch.argmax(pi, dim=1) + else: + indices2_s = indices2_s_original + + mask2 = F.one_hot(indices2_s, num_classes=self.num_experts).to( + dtype=torch.int64, device=current_device + ) + + # Compute locations in capacity buffer + locations1 = ( + torch.cumsum(mask1, dim=0) - 1 + ) # [0,1,1,0,1,0,0] -> [0,0,0,0,1,1,1,] + locations2 = torch.cumsum(mask2, dim=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(mask1, dim=0, keepdim=True) + + # Remove locations outside capacity from mask + mask1 = mask1 * (locations1 < capacity).to( + dtype=torch.int64, device=current_device + ) # [0,1,1,0,0,0,0] + mask2 = mask2 * (locations2 < capacity).to( + dtype=torch.int64, device=current_device + ) + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + locations2_s = torch.sum(locations2 * mask2, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.to(dtype=torch.float32, device=current_device) + mask2_float = mask2.to(dtype=torch.float32, device=current_device) + gates1_s = (gates * mask1_float).sum(dim=-1) + gates2_s = (gates * mask2_float).sum(dim=-1) + # logger.info(f'gates1_s:{gates1_s} gates2_s:{gates2_s} logits:{logits}') + + if self.norm_gate_logits: + denom_s = gates1_s + gates2_s # [0.2, 0.3] + # Avoid divide-by-zero + denom_s = torch.clamp(denom_s, min=1e-6) + gates1_s /= denom_s + gates2_s /= denom_s + if self.training and self.expert_drop: + # log.debug(gates2_s) + gates2_s = torch.where( + 2 * gates2_s < torch.rand_like(gates2_s, device=current_device), + torch.zeros_like(gates2_s, device=current_device), + gates2_s, + ) + + # Calculate combine_weights and dispatch_mask + gates1 = gates1_s.unsqueeze(1) * mask1_float + gates2 = gates2_s.unsqueeze(1) * mask2_float + + combine1_weight, expert1_index = torch.max(gates1, dim=-1, keepdim=True) + scatter1_index = expert1_index.squeeze(-1) * capacity + locations1_s + scatter1_index = scatter1_index.to(dtype=torch.int64, device=current_device) + dispatch1_mask = combine1_weight.to( + dtype=torch.bool, device=current_device + ).detach() + + combine2_weight, expert2_index = torch.max(gates2, dim=-1, keepdim=True) + scatter2_index = expert2_index.squeeze(-1) * capacity + locations2_s + scatter2_index = scatter2_index.to(dtype=torch.int64, device=current_device) + dispatch2_mask = combine2_weight.to( + dtype=torch.bool, device=current_device + ).detach() + # logger.info(f'expert-id: {expert1_index} vs {expert2_index}, mask:{mask1_float} vs {mask2_float}') + + return ( + capacity, + torch.cat((dispatch1_mask, dispatch2_mask), 1), + torch.cat((combine1_weight, combine2_weight), 1), + torch.stack((scatter1_index, scatter2_index), 1), + None, + None, + ) + + def _cal_orthogonal_loss_opt_each_weight(self, weight, use_group): + """ + Calculate optimized orthogonal loss for each weight. 
+ + Args: + weight: Weight tensor + use_group: Whether to use expert groups + + Returns: + Tensor: Calculated orthogonal loss + """ + if weight.dtype != torch.float32: + weight = weight.to(torch.float32) + + wnorm = torch.norm(weight, p=2, dim=1) + weight = weight / torch.maximum(wnorm, self.eps.to(weight.device)).unsqueeze(1) + + if use_group: + weight = weight.reshape( + [self.config.moe_k, -1, weight.shape[1]] + ) # [K, E/K, H] + eye_matrix = torch.eye( + weight.shape[1], dtype=weight.dtype, device=weight.device + ).unsqueeze(0) + else: + eye_matrix = torch.eye( + weight.shape[0], dtype=weight.dtype, device=weight.device + ) + + weight_matmul = torch.matmul(weight, weight.T) + + orthogonal_loss = weight_matmul - eye_matrix + orthogonal_loss = _squared_l2_norm(orthogonal_loss) / ( + orthogonal_loss.size(0) * orthogonal_loss.size(1) + ) + return orthogonal_loss + + +class TopKGate(Top2Gate): + """ + Fused version of TopK gate for improved performance. + """ + + def forward( + self, + input: torch.Tensor, + token_type_ids=None, + transform_weight=True, + is_multimodel=True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for fused gate. + + Args: + input: Input tensor + token_type_ids: Token type IDs + transform_weight: Whether to transform weights + + Returns: + tuple: (logits, capacity, router_loss) + """ + current_device = input.device + weight = self.get_gate_weight(transform_weight, is_multimodel=is_multimodel) + + logits = F.linear( + input.to(dtype=torch.float32, device=current_device), + weight.T.to(dtype=torch.float32, device=current_device), + ) + if self.use_token_type_bias: + assert token_type_ids is not None + assert ( + token_type_ids.max() < self.bias.shape[0] + ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" + bias = self.bias[token_type_ids] # [seq] + logits = logits + bias + + return logits + + +gate_class = dict( + top2=Top2Gate, + topk=TopKGate, +) + + +def get_gate( + config: Ernie4_5_MoEConfig, + expert: nn.Module, + layer_idx: int, +) -> Tuple[nn.Module, nn.ModuleList]: + """Initialize and distribute MoE (Mixture of Experts) components. + + Creates gate layer and distributed expert network for MoE architecture. 
+ + Args: + config (Ernie4_5_MoEConfig): Configuration for MoE architecture + expert (nn.Module): Prototype expert network to be replicated + layer_idx (int): Index of current layer in transformer stack + + Returns: + Tuple[nn.Module, nn.ModuleList]: + - gate: Initialized gate layer for routing + - experts: ModuleList containing expert networks + """ + moe_num_experts = ( + sum(config.moe_num_experts) + if config.multimodel_experts + else config.moe_num_experts + ) + experts = nn.ModuleList([]) + + for expert_id, (experts_num, fc) in enumerate(expert): + experts_to_append = [] + if not hasattr(fc, "__len__"): # run this + experts_to_append.append(fc) + if expert_id == 1: + with UniqueNameGuard("_mm_deepcopy"): + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + experts_to_append = fc + for ex in experts_to_append: + for p in ex.parameters(): + p.expert_type = f"expert_type_{expert_id}" # Different `expert_type` can have different intermediate-size + index = 0 + for i in range(experts_num): + if i // experts_num == 0: + experts.append(experts_to_append[index]) + index += 1 + else: + experts.append(None) + + assert ( + len(experts) == moe_num_experts + ), f"experts.len={len(experts)} != experts_num={experts_num}" + logger.info(f"MOE-GATE:-{config.moe_gate}") + + gate = gate_class[config.moe_gate.lower()](config, layer_idx=layer_idx) + + if config.multimodel_experts and config.moe_use_hard_gate and moe_num_experts > 2: + lm_experts = experts[: config.moe_num_experts[0]] + lm_gate = gate + else: + if config.multimodel_experts and config.moe_use_hard_gate: + lm_gate, lm_experts = gate, experts + else: + lm_gate, lm_experts = None, None + + logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") + + return gate, experts, lm_gate, lm_experts + + +class MoEStatics(nn.Module): + """ + Stores MoE (Mixture of Experts) statistics + and expert usage information. + """ + + def __init__(self, config, layer_idx): + """ + Initialize MoE statistics tracking. + + Args: + config: Model configuration containing MoE parameters + layer_idx: Index of the MoE layer in the model + """ + super().__init__() + self._cast_to_low_precision = False + self._cast_to_low_precison = False + num_experts = ( + config.moe_num_experts[0] + if config.multimodel_experts + else config.moe_num_experts + ) + if config.multimodel_experts: + assert ( + len(set(config.moe_num_experts)) == 1 + ), "assume expert group has same size, got: {config.moe_num_experts}" + + with UniqueNameGuard(f"mm_layer_{layer_idx}_"): + num_experts_groups = ( + len(config.moe_num_experts) if config.multimodel_experts else 1 + ) + p = nn.Parameter( + torch.zeros(num_experts_groups, num_experts, dtype=torch.float32), + requires_grad=False, + ) + self.e_score_correction_bias = p + p = torch.zeros(num_experts_groups, num_experts, dtype=torch.int64) + self.expert_usage = p + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + """ + Reorders input tensor based on gate results with capacity truncation and padding. 
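+ + Each selected token is written (via scatter_add) into a flat buffer of shape [num_experts * capacity, dim] at offset expert_id * capacity + slot; tokens dropped by the dispatch mask contribute zeros.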
+ + Args: + x (Tensor): Input tensor of shape [Seq, Dim] + dispatch_mask (Tensor): Dispatching mask of shape [Seq, 2] + scatter_index (Tensor): Scatter indices of shape [Seq, 2] + num_experts (int): Number of experts + capacity (int): Capacity per expert + + Returns: + Tensor: Dispatched output tensor of shape [Expert*Capacity, Dim] + """ + output = None + orig_dtype = x.dtype + scatter_index_unbound = [scatter_index[:, 0], scatter_index[:, 1]] + dispatch_mask_unbound = [dispatch_mask[:, 0], dispatch_mask[:, 1]] + + for i_scatter_index, i_dispatch_mask in zip( + scatter_index_unbound, dispatch_mask_unbound + ): + updates = x * i_dispatch_mask.unsqueeze(-1).to(orig_dtype) # [seq, dim] + init_output = torch.zeros( + num_experts * capacity, x.shape[-1], dtype=orig_dtype, device=x.device + ) + + index = i_scatter_index.unsqueeze(-1).expand(-1, x.shape[-1]) # [seq, dim] + if output is None: + output = init_output.scatter_add(0, index, updates) + else: + output = output + init_output.scatter_add(0, index, updates) + if output.dtype != orig_dtype: + output = output.to(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + """ + Combines and aggregates input matrix using combination weights. + + Args: + x (Tensor): Input tensor of shape [num_experts * capacity, dim] + combine_weights (Tensor): Combination weights of shape [seq, 2] + scatter_index (Tensor): Scatter indices of shape [seq, 2] + + Returns: + Tensor: Combined output tensor of shape [seq, dim] + """ + dim = x.shape[-1] + + current_device = scatter_index.device + x = x.to(current_device) + scatter_index = scatter_index.reshape([-1]) + num_k = combine_weights.shape[-1] + + combine_weights = combine_weights.unsqueeze(1).to(current_device) + + x = x[scatter_index].reshape([-1, num_k, dim]) # [seq, 2, dim] + + return torch.matmul(combine_weights, x).squeeze( + 1 + ) # [seq, 1, 2] @ [seq, 2, dim] -> [seq, 1, dim] + + +class MOELayer(nn.Module): + """ + Mixture of Experts layer implementation based on GShard paper. + """ + + def __init__( + self, + gate: nn.Module, + experts: List[nn.Module], + layer_idx: int, + shared_experts: Optional[List[nn.Module]] = None, + group=None, + recompute: bool = False, + k: int = 2, + all_to_all_dropout: float = 0, + group_experts: bool = False, + moe_statics=None, + moe_num_experts=None, + ): + """ + Initialize MoE layer. 
+ + Args: + gate: Gate network for expert selection + experts: List of expert networks + layer_idx: Index of this layer in the model + group: Distributed communication group + recompute: Whether to enable recomputation + k: Number of experts to select per token + all_to_all_dropout: Dropout rate for all-to-all communication + group_experts: Whether to group experts + moe_statics: MoE statistics tracking object + """ + super().__init__() + self.gate = gate + self.layer_idx = layer_idx + + if isinstance(experts, nn.ModuleList): + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + self.use_correction_bias = moe_statics is not None + self.moe_statics = moe_statics + if self.use_correction_bias: + logger.info( + f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}" + ) + assert self.gate.config.moe_use_aux_free + + try: + self.world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + except: + self.world_size = 1 + self.rank = 0 + if self.world_size < 1: + self.world_size = 1 + if self.rank < 0: + self.rank = 0 + + self.multimodal_experts = ( + isinstance(moe_num_experts, (tuple, list)) and len(moe_num_experts) > 1 + ) + self.num_local_experts = len(self.experts) // self.world_size + if self.multimodal_experts: + self.num_local_multimodal_experts = [ + num // self.world_size for num in moe_num_experts + ] + self.multimodal_expert_index = [0] + list( + itertools.accumulate(moe_num_experts) + ) + + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + self.config = self.gate.config + self.zero = torch.tensor(0).to(dtype=torch.float32) + + def forward_experts(self, dispatched_input): + """ + Forward pass through experts sequentially. 
+ + Args: + dispatched_input: Input tensor of shape [num_experts, capacity, dim] + + Returns: + Tensor: Expert outputs of shape [num_experts, capacity, dim] + """ + + if not self.multimodal_experts: + true_experts = self.experts[ + self.rank + * self.num_local_experts : (self.rank + 1) + * self.num_local_experts + ] + else: + true_experts = [] + for i, num in enumerate(self.num_local_multimodal_experts): + current_modal_experts = self.experts[ + self.multimodal_expert_index[i] : self.multimodal_expert_index[ + i + 1 + ] + ] + true_experts.extend( + current_modal_experts[self.rank * num : (self.rank + 1) * num] + ) + + dispatched_input = dispatched_input.reshape( + [self.world_size, self.num_local_experts, -1, dispatched_input.shape[-1]] + ) + current_device = dispatched_input.device + expert_outputs = [] + if isinstance(self.experts, nn.ModuleList): + chunks = dispatched_input.permute(1, 0, 2, 3).contiguous().unbind(0) + assert len(chunks) == len( + true_experts + ), f"{len(chunks)}, {len(true_experts)}" + for chunk, expert in zip(chunks, true_experts): + expert_outputs.append(expert(chunk)) + else: + dispatched_input = dispatched_input.permute(1, 0, 2, 3).contiguous() + orig_shape = dispatched_input.shape + chunks = dispatched_input.reshape(orig_shape[0], -1, orig_shape[-1]) + chunks = self.experts(chunks) + chunks = chunks.reshape(orig_shape[:-1] + (chunks.shape[-1],)).unbind(0) + expert_outputs.extend(chunks) + + for i, expert_output in enumerate(expert_outputs): + expert_outputs[i] = expert_output.to(current_device) + expert_output = torch.stack(expert_outputs, dim=1) + return expert_output + + def moe_gate_dispatch( + self, + x: torch.Tensor, # [S, H] float16 / float32 / bfloat16 + gate_logits: torch.Tensor, # [S, E] float32 + k: int, + capacity: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """dispatch input to experts based on gate logits""" + + S, H = x.shape + E = gate_logits.shape[1] + device = x.device + topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1) # [S, k] + combine_weights = topk_prob # [S, k] + expert_id = topk_idx # [S, k] + y = x.new_zeros((E, capacity, H)) # [E, C, H] + scatter_index = x.new_full((k, S), -1, dtype=torch.int32) # [k, S] + # per-expert slot counters + slot_counter = torch.zeros(E, dtype=torch.int32, device=device) + + for tok in range(S): + for route in range(k): + e = expert_id[tok, route].item() + slot = slot_counter[e].item() + if slot >= capacity: # expert is full -> drop + combine_weights[tok, route] = 0.0 + continue + # record mapping & dispatch activation + scatter_index[route, tok] = e * capacity + slot + y[e, slot] = x[tok] + slot_counter[e] += 1 + + expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64) + + return y, combine_weights, scatter_index, expert_offset, expert_id + + def gate_and_dispatch(self, input, token_type_ids=None, is_multimodel=True): + """ + Calculate gate and dispatch inputs. 
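+ + For a TopKGate, expert probabilities come straight from the gate logits and tokens are routed with moe_gate_dispatch; for a Top2Gate, top2_gating produces the dispatch masks and scatter indices and dispatching() reorders tokens into the [num_experts, capacity, dim] layout.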
+ + Args: + input: Input tensor of shape [seq, dim] + + Returns: + tuple: (dispatched_input, combine_weights, dispatch_mask, + scatter_index, router_loss, gate_logits, gate_prob) + """ + d_model = input.shape[1] + if isinstance(self.gate, (TopKGate)): + capacity = self.gate.get_capacity( + input.shape[0], is_multimodel=is_multimodel + ) + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + gate_logits = self.gate( + input, token_type_ids=token_type_ids, is_multimodel=is_multimodel + ) + prob = self.gate.act(gate_logits) + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity) + dispatch_mask = torch.diff(F.pad(dispatch_mask, (1, 0))) + + scatter_index.detach() + dispatch_mask.detach() + + scatter_index = scatter_index.transpose(0, 1) # [k, s] -> [s, k] + combine_weights = combine_weights_unnorm / torch.clamp( + combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12 + ) + combine_weights = combine_weights.to(dtype=dispatched_input.dtype) + + else: + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + gate_logits, + ) = self.gate( + input, + ) + prob = None + dispatched_input = dispatching( + input, + dispatch_mask, + scatter_index, + num_experts=self.world_size * self.num_local_experts, + capacity=capacity, + ) + + dispatched_input = dispatched_input.reshape( + [self.world_size * self.num_local_experts, capacity, d_model] + ) + + dispatch_mask = dispatch_mask.detach() + scatter_index = scatter_index.detach() + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + None, + gate_logits, + prob, + ) + + def combine_expert_output(self, expert_output, combine_weights, scatter_index): + """ + Combine expert outputs using combination weights. + + Args: + expert_output: Expert outputs [num_experts, capacity, dim] + combine_weights: Combination weights + scatter_index: Scatter indices + + Returns: + Tensor: Combined output [seqlen, dim] + """ + expert_output = expert_output.reshape( + [-1, expert_output.shape[-1]] + ) # [e*1,c,m] + + combined_output = combining(expert_output, combine_weights, scatter_index) + + if self.output_postprocess is not None: + combined_output = self.output_postprocess(combined_output) + + return combined_output + + def forward( + self, + input: torch.Tensor, + token_type_ids=None, + is_multimodel=True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass through MoE layer. 
+ + Args: + input: Input tensor of shape [s, d] + + Returns: + tuple: (output, combine_weights, router_loss, gate_logits) + """ + if input.dim() == 3: + orig_shape = input.shape + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + assert ( + input.dim() == 2 + ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + if token_type_ids is not None: + token_type_ids = token_type_ids.clone()[:, :-1] + + assert self.gate is not None + + gate_input = input + + ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + gate_prob, + ) = self.gate_and_dispatch( + gate_input, token_type_ids, is_multimodel=is_multimodel + ) + + if self.shared_experts is not None: + shared_out = self.shared_experts(input) + + expert_out = self.forward_experts(dispatched_input) + + combined_output = self.combine_expert_output( + expert_out, combine_weights, scatter_index + ) + + if self.shared_experts is not None: + combined_output += shared_out + + if orig_shape: + combined_output = combined_output.clone().reshape( + orig_shape[:-1] + (combined_output.shape[-1],) + ) + return combined_output, combine_weights, None, gate_logits + + +class MOEAllGatherLayerV2(MOELayer): + """ + MoE layer with an allgather implementation. + """ + + def __init__( + self, + gate: nn.Module, + experts: List[nn.Module], + layer_idx, + shared_experts: Optional[List[nn.Module]] = None, + group=None, + recompute=False, + k=2, + enable_reverse_token_drop=False, + all_to_all_dropout=0, + group_experts=False, + use_expert_out_alltoall=True, # + use_expert_alltoall_overlap=False, + use_padding=True, + dense_token_type=3, # considered as dense tokens (no moe) + moe_statics=None, + moe_num_experts=None, + ): + super().__init__( + gate, + experts, + layer_idx, + shared_experts, + group, + recompute, + k, + all_to_all_dropout, + group_experts, + moe_statics, + moe_num_experts, + ) + self.enable_reverse_token_drop = enable_reverse_token_drop + self.is_allgather_moe_layer = True + self.use_padding = use_padding + + self.send_rank = None + self.local_expert_id = None + self.dense_experts = None + self.dense_token_type = dense_token_type + self.capacity_tensor = None + logger.info( + f"using MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " # false + f"use_padding={use_padding}, use_expert_alltoall_overlap={use_expert_alltoall_overlap} " # true false + f"enable_reverse_token_drop={self.enable_reverse_token_drop}" # false + ) + self.two = torch.tensor(2).to(dtype=torch.float32) + self.zero = torch.tensor(0).to(dtype=torch.float32) + + def forward( + self, + input: torch.Tensor, + token_type_ids=None, + use_dense_expert=False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Implements forward pass for Mixture-of-Experts (MoE) layer with distributed communication. + + Core Functionality: + - Processes input through gating network to determine expert assignments + - Combines expert outputs and calculates routing loss + + Key Features: + 1. Supports both dense and sparse expert computation modes + 2. Implements fused gating and dispatch for performance optimization + 3. Handles sequence length padding/unpadding for irregular inputs + 4. 
Enables communication-computation overlap through asynchronous operations + + Args: + input (Tensor): Input tensor of shape [seq_len, hidden_dim] + token_type_ids: Optional segmentation markers for heterogeneous inputs + use_dense_expert: Flag to enable dense expert computation bypass + + Returns: + tuple: ( + combined_output: Aggregated expert outputs [seq_len, hidden_dim], + combine_weights: Expert combination coefficients, + ) + """ + use_fuse = isinstance(self.gate, (TopKGate)) + assert use_fuse + if input.ndim == 3: + orig_shape = input.shape + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + + assert ( + len(input.shape) == 2 + ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + dispatch_token_type_ids = None + global_dense_expert_mask = None + if token_type_ids is not None: + token_type_ids = token_type_ids[:, :-1].reshape([-1]) + dispatch_token_type_ids = token_type_ids + if use_dense_expert: + global_dense_expert_mask = ( + dispatch_token_type_ids == self.dense_token_type + ) + + assert self.gate is not None + + ( + dispatched_input, + global_hidden_states, + local_combine_weights, + expert_num_global_no_token_drop, + expert_num_global, + expert_num_global_list, + local_scatter_index, + scatter_index_rev, + router_loss, + (gate_logits, gate_prob), + (gate_logits_mm, gate_prob_mm), + expert_num_local, + ) = self.fused_gate_and_dispatch( + input, token_type_ids, global_dense_expert_mask + ) + + seqlen_this_mp = input.shape[0] + if len(scatter_index_rev): + recv_rank_local = scatter_index_rev // seqlen_this_mp + else: + recv_rank_local = scatter_index_rev + + if self.send_rank is None: + capacity = self.gate.get_capacity(input.shape[0]) + self.send_rank = ( + torch.arange(1) + .repeat_interleave(capacity * self.num_local_experts) + .to(torch.int32) # cap + ) + self.local_expert_id = ( + torch.arange(self.num_local_experts) + .repeat_interleave(capacity) + .repeat(1) + .to(self.send_rank.dtype) + ) + send_rank = self.send_rank + local_expert_id = self.local_expert_id + + expert_outs = self.forward_experts(*dispatched_input) + for e in expert_outs: + if e is not None: + current_device = e.device + break + expert_outs = torch.cat( + [e.to(current_device) for e in expert_outs if e is not None], dim=0 + ) # [e*c,m] + + # global -> local + combined_output = self.combine_expert_output( + expert_outs, local_combine_weights, local_scatter_index + ) + + if self.shared_experts is not None: + shared_out = self.shared_experts(input).to(combined_output.device) + combined_output += shared_out + + if orig_shape: + combined_output = combined_output.reshape( + *orig_shape[:-1], combined_output.shape[-1] + ) + + return combined_output, local_combine_weights, None, gate_logits + + def _expand_modality_expert_id( + self, + expert_id: torch.Tensor, # (seqlen, k) + seqlen: int, + k: int, + num_expert_per_modality: int, + group_size: int, + modality_offset: int, + is_group_expert: bool, + ) -> torch.Tensor: + """ + expert_id: tensor of shape (seqlen, k), containing expert ids + Returns: tensor of same shape, with updated expert ids + """ + device = expert_id.device + expert_id = expert_id.clone() + + if is_group_expert: + # idx % k * group_size + offsets = (torch.arange(k, device=device) * group_size).view( + 1, k + ) # shape (1, k) + expert_id += offsets + + if num_expert_per_modality <= 0: + return expert_id + + # Compute rank and local expert id + rank = expert_id // num_expert_per_modality + expert_id_in_rank = expert_id % num_expert_per_modality + + 
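# Illustrative example: with num_expert_per_modality=4, expert_id=6 lands on rank 1 with local id 2, so with modality_offset=0 it maps to 1*8 + 2 + 0 = 10 and with modality_offset=1 to 1*8 + 2 + 4 = 14. +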
# Compute new expert id with modality-aware adjustment + expert_id_out = ( + rank * (num_expert_per_modality * 2) # 2 modalities assumed + + expert_id_in_rank + + modality_offset * num_expert_per_modality + ) + + return expert_id_out + + def expand_modality_expert_id( + self, + expert_id, + num_expert_per_modality, + group_size, + modality_offset, + is_group_expert, + ): + """expand expert id for modality aware moe layer""" + seq_len, k = expert_id.shape + + return self._expand_modality_expert_id( + expert_id, + seq_len, + k, + num_expert_per_modality, + group_size, + modality_offset, + is_group_expert, + ) + + def fused_gate_logits_process_fused( + self, gate_logits_lm, gate_logits_mm=None, token_type_ids=None + ): + """Process gating logits for expert selection in Mixture-of-Experts (MoE) layers. + + Core Functionality: + - Transforms raw gating logits into expert selection weights and IDs + - Supports both grouped and standard expert selection modes + - Handles bias correction for improved expert load balancing + + Args: + gate_logits_lm (Tensor): Raw gating scores of shape [batch_size, total_experts] + + Returns: + tuple: ( + lm_weight_and_expert_id: Combined tensor containing selection weights + and expert IDs [batch_size, 2*top_k], + prob_flat: Flattened expert probabilities [batch_size, total_experts] + ) + """ + top_k = self.k + num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] + group_size = gate_logits_lm.shape[-1] // top_k + if self.group_experts: + assert not self.use_correction_bias + gate_logits_lm = gate_logits_lm.reshape( + [gate_logits_lm.shape[0], top_k, -1] + ) + prob_lm = self.gate.act(gate_logits_lm) + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=1, dim=-1) + weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) + group_size = gate_logits_lm.shape[-1] + expert_id_lm = expert_id_lm.squeeze(-1) + else: + prob_lm = self.gate.act(gate_logits_lm) + if self.use_correction_bias: + prob_lm_ = prob_lm + self.moe_statics.e_score_correction_bias[ + 0 + ].detach().to(prob_lm.device) + else: + prob_lm_ = prob_lm + weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, dim=-1) + + if self.use_correction_bias: + batch_idx = ( + torch.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + ) + weight_lm = prob_lm[batch_idx, expert_id_lm] # use correct bias + + expert_id_lm = self.expand_modality_expert_id( + expert_id_lm, + num_expert_per_modality=( + num_expert_per_rank_per_modality if token_type_ids is not None else 0 + ), + group_size=group_size, + modality_offset=0, + is_group_expert=self.group_experts, + ) + expert_id_lm = expert_id_lm.reshape(weight_lm.shape) + lm_weight_and_expert_id = torch.cat( + [weight_lm, expert_id_lm.to(torch.float32)], -1 + ) + + if token_type_ids is None or gate_logits_mm is None: + return ( + lm_weight_and_expert_id, + prob_lm.reshape([prob_lm.shape[0], -1]), + None, + ) + + prob_mm = self.gate.act(gate_logits_mm) + if self.use_correction_bias: + prob_mm_ = prob_mm + self.moe_statics.e_score_correction_bias[ + 1 + ].detach().to(prob_lm.device) + else: + prob_mm_ = prob_mm + weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, dim=-1) + if self.use_correction_bias: + batch_idx = ( + torch.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) + ) + weight_mm = prob_mm[batch_idx, expert_id_mm] # use correct bias + + expert_id_mm = self.expand_modality_expert_id( + expert_id_mm, + num_expert_per_modality=num_expert_per_rank_per_modality, + group_size=group_size, + modality_offset=1, + is_group_expert=False, + ) 
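+        # As with the text branch above, the multimodal routing weights and the
+        # expanded expert ids are packed into a [num_tokens, 2 * top_k] tensor
+        # below, and torch.where then selects per token: rows with
+        # token_type_ids == 0 keep the LM routing, all other rows take the
+        # multimodal routing.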
+ expert_id_mm = expert_id_mm.reshape(weight_mm.shape) + mm_weight_and_expert_id = torch.cat( + [weight_mm, expert_id_mm.to(torch.float32)], -1 + ) + weight_and_expert = torch.where( + (token_type_ids == 0).unsqueeze(-1), + lm_weight_and_expert_id.to(token_type_ids.device), + mm_weight_and_expert_id.to(token_type_ids.device), + ) + return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm + + def moe_gate_dispatch_partial_nosoftmaxtopk( + self, + x, + combine_weights, + expert_id, + k, + num_experts, + ): + """ + MoE Gate Dispatch kernel + """ + device = x.device + dtype = x.dtype + num_rows, hidden_size = x.shape + k = expert_id.shape[1] + expert_ids_flat = expert_id.reshape(-1) # [num_rows * k] + combine_weights_flat = combine_weights.reshape(-1) # [num_rows * k] + + expanded_token_ids = torch.arange(num_rows * k, device=device) # [num_rows * k] + + sorted_expert_ids, sorted_indices = torch.sort(expert_ids_flat, stable=True) + sorted_indices = sorted_indices.to(expanded_token_ids.device) + + sorted_expanded_token_ids = expanded_token_ids[sorted_indices] + + expert_nums_local = torch.zeros(num_experts, dtype=torch.int64, device=device) + + for expert_idx in range(num_experts): + count = (sorted_expert_ids == expert_idx).sum().item() + expert_nums_local[expert_idx] = count + + total_dispatched_tokens = torch.cumsum(expert_nums_local, dim=0)[-1].item() + + y = x[sorted_indices // k] # [total_dispatched_tokens, hidden_size] + + scatter_index = torch.full((k, num_rows), -1, dtype=torch.int32, device=device) + + for i, (expanded_idx, sorted_pos) in enumerate( + zip(sorted_expanded_token_ids, range(total_dispatched_tokens)) + ): + token_idx = expanded_idx // k + k_idx = expanded_idx % k + scatter_index[k_idx, token_idx] = sorted_pos + + scatter_index_rev = sorted_indices // k + + combine_weights_out = combine_weights.clone() + + return ( + y, # [total_dispatched_tokens, hidden_size] + combine_weights_out, # [num_rows, k] + scatter_index, # [k, num_rows] + scatter_index_rev, # [total_dispatched_tokens] + expert_nums_local, # [num_experts] + expert_nums_local, # [num_experts] + ) + + def fused_gate_and_dispatch( + self, input, token_type_ids=None, global_dense_expert_mask=None + ): + """Implements fused expert gating and token dispatch logic for Mixture-of-Experts (MoE) layers. 
+ + Core Functionality: + - Computes expert selection probabilities and routing weights + - Performs distributed token-to-expert assignment + - Handles communication and synchronization in model-parallel environments + + Args: + input (Tensor): Input tensor of shape [seq_len, hidden_dim] + + Returns: + tuple: ( + dispatched_input: Expert-assigned tokens [num_experts, capacity, hidden_dim], + global_hidden_states: Full sequence representations, + local_combine_weights: Local expert combination weights, + expert_num_global_notrunc: Global expert token counts (without capacity truncation), + expert_num_global: Actual expert token counts, + expert_num_global_list: Per-expert token counts, + local_scatter_index: Local token reorganization indices, + scatter_index_rev: Reverse scattering indices, + router_loss: Calculated routing loss, + gate_outputs: Raw gating network outputs, + expert_num_local: Local expert utilization counts + ) + """ + seqlen, d_model = input.shape + args = () + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + args = (token_type_ids,) + + router_loss = torch.zeros([1], dtype=torch.float32) + top_k = self.k + + def build_weights_and_expert_id(input): + nonlocal token_type_ids, args + logits = self.gate(input, *args, transform_weight=False) + if self.config.multimodel_experts: + gate_logits_lm, gate_logits_mm = logits.chunk(2, dim=-1) + else: + gate_logits_lm, gate_logits_mm = logits, None + + weigth_and_expert, gate_prob_lm, gate_prob_mm = ( + self.fused_gate_logits_process_fused( + gate_logits_lm, + gate_logits_mm, + token_type_ids if global_dense_expert_mask is None else None, + ) + ) + return ( + weigth_and_expert, + gate_logits_lm, + gate_logits_mm, + gate_prob_lm, + gate_prob_mm, + ) + + capacity = self.gate.get_capacity(input.shape[0]) * self.world_size + global_hidden_states = input + ( + combine_weights_and_expert_id, + gate_logits_lm, + gate_logits_mm, + gate_prob_lm, + gate_prob_mm, + ) = build_weights_and_expert_id(input) + + combine_weights_unnorm, expert_id = combine_weights_and_expert_id.chunk( + 2, dim=-1 + ) + expert_id = expert_id.to(torch.int32) + num_experts = ( + sum(self.config.moe_num_experts) + if isinstance(self.config.moe_num_experts, (tuple, list)) + else self.config.moe_num_experts + ) + if global_dense_expert_mask is not None: + combine_weights_unnorm[global_dense_expert_mask] = 0.0 + expert_id[global_dense_expert_mask] = num_experts + num_experts += 1 + + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, # input -> dispatched_input + scatter_index_rev, # dispatch-input -> input + expert_num_global, + expert_num_local, + ) = self.moe_gate_dispatch_partial_nosoftmaxtopk( + global_hidden_states, + combine_weights_unnorm, + expert_id, + top_k, + num_experts, + ) + + if self.use_correction_bias: + if self.gate.config.multimodel_experts: + # MLLM + for i in range(len(self.moe_statics.expert_usage)): + self.moe_statics.expert_usage[i] += ( + expert_num_local[self.gate.experts_type_mask[i]] + .detach() + .to(self.moe_statics.expert_usage.device) + ) + else: + # LLM + self.moe_statics.expert_usage[0] += expert_num_local.detach().to( + self.moe_statics.expert_usage.device + ) + + # When use unpad , `moe_ops_partial` output likes `scatter_index_rev==[]`. 
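+        # That is, without padding the dispatch may return an empty reverse
+        # index; a 0-dim result is normalised to an empty 1-D tensor below so
+        # the caller's `len(scatter_index_rev)` check stays valid.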
+ if scatter_index_rev.ndim == 0: + assert not self.use_padding + scatter_index_rev = torch.empty([0], dtype=scatter_index_rev.dtype) + + expert_num_global_notrunc = expert_num_global + self.capacity_tensor = torch.tensor(capacity).to(dtype=expert_num_global.dtype) + expert_num_global = torch.minimum(expert_num_global, self.capacity_tensor) + + if global_dense_expert_mask is not None: + expert_num_global = expert_num_global[:-1] + expert_num_local = expert_num_local[:-1] + expert_num_global_notrunc = expert_num_global_notrunc[:-1] + + scatter_index = scatter_index.transpose(1, 0) # [k,s] ->[s,k] + scatter_index = scatter_index.to(combine_weights_unnorm.device) + + last_local_expert = 0 + expert_offset_global = expert_num_global.cumsum(-1) + + expert_num_global_list = expert_num_global + if self.use_padding: + offset = last_local_expert * capacity + else: + offset = 0 + local_combine_weights_unnorm = combine_weights_unnorm.contiguous() + local_scatter_index = torch.where( + combine_weights_unnorm > 0.0, + scatter_index + offset, + scatter_index, + ) + if self.gate.norm_gate_logits: + local_combine_weights = local_combine_weights_unnorm / torch.clip( + local_combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + local_combine_weights = local_combine_weights_unnorm + local_combine_weights = local_combine_weights.to(dispatched_input.dtype) + if self.use_padding: + dispatched_input = dispatched_input.reshape( + [self.num_local_experts, -1, d_model] + ) + dispatched_input = dispatched_input.unbind(0) + else: + s = 0 + e = self.num_local_experts + expert_num_local = expert_num_local.tolist()[s:e] + expert_num_local_valid = [i for i in expert_num_local if i > 0] + valid_pos = [j for j, i in enumerate(expert_num_local) if i > 0] + if expert_num_local_valid: + dispatched_input_list = dispatched_input.split(expert_num_local_valid) + dispatched_input = [None] * len(expert_num_local) + for p, t in zip(valid_pos, dispatched_input_list): + dispatched_input[p] = t + else: + dispatched_input = [dispatched_input] + ( + [None] * (len(expert_num_local) - 1) + ) + + expert_num_global_list = expert_num_global_list.tolist() + + return ( + dispatched_input, + global_hidden_states, + local_combine_weights, + expert_num_global_notrunc, # for auxloss calculation. + expert_num_global, + expert_num_global_list, + local_scatter_index, + scatter_index_rev, + router_loss, + (gate_logits_lm, gate_prob_lm), + (gate_logits_mm, gate_prob_mm), + expert_num_local, + ) + + def forward_experts(self, *dispatched_input): + """Execute expert model computations in sequence for Mixture-of-Experts (MoE) layer. + + Core Functionality: + - Distributes dispatched tokens to local expert models + - Handles empty expert inputs with zero-initialized fallback + - Maintains gradient flow for expert outputs + - Aggregates outputs from all active experts + + Args: + *dispatched_input: Variable-length expert-specific input tensors + + Returns: + list: Expert output tensors (None for inactive experts) + + Implementation Details: + 1. Processes valid expert inputs through corresponding expert models + 2. Generates dummy inputs for inactive experts to preserve model structure + 3. 
Aggregates dummy outputs to first active expert to maintain gradient flow + """ + expert_outputs = [] + assert isinstance(self.experts, nn.ModuleList), type(self.experts) + + no_tokens_expert_outputs = [] + true_experts = self.experts[ + self.rank + * self.num_local_experts : (self.rank + 1) + * self.num_local_experts + ] + for iexpert, chunk in enumerate(dispatched_input): + if chunk is None: + expert_outputs.append(None) + continue + + expert_out = true_experts[iexpert](chunk.contiguous()) + expert_outputs.append(expert_out) + + if len(no_tokens_expert_outputs) > 0: + first_has_tokens_idx = 0 + for idx, expert_out in enumerate(expert_outputs): + if expert_out is not None: + first_has_tokens_idx = idx + break + for idx, expert_out in enumerate(no_tokens_expert_outputs): + expert_outputs[first_has_tokens_idx] += expert_out + + return expert_outputs + + +class Ernie4_5_DecoderLayer(nn.Module): + """A single transformer decoder layer in ERNIE-MoE model. + + Contains self-attention and feed-forward components with optional MoE (Mixture of Experts) + support, residual connections, and layer normalization. + """ + + _keep_in_fp32_modules = ["mlp.gate", "e_score_correction_bias"] + + def __init__(self, config, layer_idx): + """Initialize the decoder layer. + + Args: + config (Ernie4_5_MoEConfig): Model configuration. + layer_idx (int): Index of this layer in the transformer stack + """ + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.config = config + self.use_moe = config.use_moe + self.self_attn = Ernie4_5_Attention(config, layer_idx) + + moe_layer_start_index = ( + min(config.moe_layer_start_index) + if isinstance(config.moe_layer_start_index, (tuple, list)) + else config.moe_layer_start_index + ) + moe_layer_end_index = ( + max(config.moe_layer_end_index) + if isinstance(config.moe_layer_end_index, (tuple, list)) + else config.moe_layer_end_index + ) + + if ( + self.use_moe + and ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= moe_layer_start_index # 3 + and layer_idx <= moe_layer_end_index # 53 + ): + gate, experts, lm_gate, lm_experts, moe_statics = ( + self._init_gate_and_experts(layer_idx) + ) + shared_experts = ( + self._init_shared_experts() + if hasattr(config, "moe_num_shared_experts") + else None + ) + + dense_experts = None + moe_cls = MOELayer + if config.moe_multimodal_dispatch_use_allgather: # v2 + logger.info("Enable MOEAllGatherLayerV2!") + moe_cls = partial( + MOEAllGatherLayerV2, + use_expert_out_alltoall="alltoall" + in config.moe_multimodal_dispatch_use_allgather, # false + use_padding=False, + enable_reverse_token_drop=config.moe_reverse_token_drop, # false + dense_token_type=config.moe_dense_experts_token_type_id, # 3 + ) + else: + assert ( + dense_experts is None + ), "only `MOEAllGatherLayerV2` can process dense experts" + + self.mlp = moe_cls( + gate=gate, + experts=experts, + layer_idx=layer_idx, + shared_experts=shared_experts, + group=config.moe_group, + recompute=False, + k=config.moe_k, + all_to_all_dropout=config.moe_all_to_all_dropout, + group_experts=False, + moe_statics=moe_statics, + moe_num_experts=config.moe_num_experts, + ) + + _mlp_text = MOELayer( + gate=lm_gate, + experts=lm_experts, + layer_idx=layer_idx, + shared_experts=shared_experts, + group=config.moe_group, + recompute=False, + k=config.moe_k, + all_to_all_dropout=config.moe_all_to_all_dropout, + group_experts=False, + moe_statics=moe_statics, + moe_num_experts=config.moe_num_experts, + ) + self.mlp_text = ( + lambda: 
_mlp_text + ) # This lambda prevents the text parameter from being scanned into the state-dict + else: + self.mlp = Ernie4_5_MLP(config) + + Norm = RMSNorm + + self.input_layernorm = Norm(config) + self.post_attention_layernorm = Norm(config) + + self.residual_add1 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) + self.residual_add2 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) + + def _init_shared_experts(self): + """init shared experts + + Returns: + _type_: _description_ + """ + cfg = deepcopy(self.config) + if cfg.moe_num_shared_experts > 0: + if cfg.moe_intermediate_size: + inter_size = ( + next(iter(cfg.moe_intermediate_size)) + if isinstance(cfg.moe_intermediate_size, (tuple, list)) + else cfg.moe_intermediate_size + ) + cfg.intermediate_size = inter_size * cfg.moe_num_shared_experts + else: + cfg.intermediate_size = ( + cfg.intermediate_size * cfg.moe_num_shared_experts + ) + cfg.disable_ffn_model_parallel = False # split shared epxert + shared_experts = Ernie4_5_MoeMLP(cfg, True) + else: + shared_experts = None + return shared_experts + + def _init_gate_and_experts(self, layer_idx): + """Initialize MoE gate and expert networks. + + Args: + layer_idx (int): Current layer index + + Returns: + Tuple: Contains: + - gate: MoE routing gate + - experts: List of expert networks + - moe_statics: Optional statistics tracker + """ + cfg = deepcopy(self.config) + fc_cls = Ernie4_5_MoeMLP + if cfg.moe_intermediate_size: + if isinstance(cfg.moe_intermediate_size, (tuple, list)): + assert isinstance(cfg.moe_num_experts, (tuple, list)) and len( + cfg.moe_num_experts + ) == len(cfg.moe_intermediate_size) + fc = [] + for _i, (num_experts, intermediate_size) in enumerate( + zip(cfg.moe_num_experts, cfg.moe_intermediate_size) + ): + ex_cfg = deepcopy(cfg) + ex_cfg.intermediate_size = intermediate_size + cur_modality_start_layer_idx = ( + cfg.moe_layer_start_index[_i] + if isinstance(cfg.moe_layer_start_index, (tuple, list)) + else cfg.moe_layer_start_index + ) + cur_modality_end_layer_idx = ( + cfg.moe_layer_end_index[_i] + if isinstance(cfg.moe_layer_end_index, (tuple, list)) + else cfg.moe_layer_end_index + ) + if ( + layer_idx >= cur_modality_start_layer_idx + and layer_idx <= cur_modality_end_layer_idx + ): + if _i == 1: + with UniqueNameGuard(f"mm_expert_{layer_idx}_") as guard: + fc.append((num_experts, fc_cls(ex_cfg))) + else: + fc.append((num_experts, fc_cls(ex_cfg))) + else: + logger.info( + f"moe multimodal experts use Identity layer_idx: {layer_idx}" + ) + fc.append((num_experts, nn.Identity())) + else: + cfg.intermediate_size = cfg.moe_intermediate_size + fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))] + else: + fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))] + if cfg.multimodel_experts: + gate, experts, lm_gate, lm_experts = get_gate(self.config, fc, layer_idx) + else: + gate, experts = get_gate(self.config, fc, layer_idx) + lm_gate, lm_experts = None, None + + # for AuxLoss Free Router: + if cfg.moe_use_aux_free: + moe_statics = MoEStatics(cfg, layer_idx) + else: + moe_statics = None + return gate, experts, lm_gate, lm_experts, moe_statics + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attn_mask_start_row_indices: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: 
Optional[bool] = False, + output_gate_logits=True, # PP model should not output gate logits, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + """Forward pass through the decoder layer. + + Args: + hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size] + attention_mask (Optional[torch.Tensor]): Attention mask tensor + attn_mask_start_row_indices (Optional[torch.Tensor]): Indices for variable length attention + position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings + output_attentions (Optional[bool]): Whether to return attention weights + past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states + use_cache (Optional[bool]): Whether to cache key/value states + output_gate_logits (bool): Whether to return MoE gate logits + + Returns: + Union: Various output combinations depending on arguments: + - Base case: Hidden states tensor + - With attention: Tuple of (hidden_states, attention_weights) + - With cache: Tuple of (hidden_states, cached_key_value) + - With MoE: May include gate logits in output tuple + """ + residual = hidden_states + + if token_type_ids is not None: + is_multimodel_token = token_type_ids.any() + has_dense_experts_token = ( + token_type_ids == self.config.moe_dense_experts_token_type_id + ).any() + is_multimodel_token_cpu = is_multimodel_token.cpu() + has_dense_experts_token_cpu = has_dense_experts_token.cpu() + else: + is_multimodel_token_cpu = None + has_dense_experts_token_cpu = None + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + (hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = ( + self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + attn_mask_start_row_indices=attn_mask_start_row_indices, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + token_type_ids=token_type_ids, + ) + ) + hidden_states = self.residual_add1(hidden_states, residual) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if isinstance(self.mlp, MOELayer): + if is_multimodel_token_cpu: + hidden_states, _, router_loss, gate_logits = self.mlp( + hidden_states, token_type_ids + ) + else: + hidden_states, _, router_loss, gate_logits = self.mlp_text()( + hidden_states, None, is_multimodel=False + ) + else: + hidden_states = self.mlp(hidden_states) + gate_logits, router_loss = None, None + + hidden_states = self.residual_add2(hidden_states, residual) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if self.use_moe: + # Non-empty only if `use_moe` + if router_loss_attn: + router_loss_attn = router_loss_attn[0] + router_loss = router_loss + router_loss_attn + + if output_gate_logits: + outputs += (gate_logits,) + + # remove empty tuple for pipeline parallel + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Ernie4_5_PretrainedModel(PreTrainedModel): + """Base class for ERNIE pretrained models.""" + + config_class = Ernie4_5_MoEConfig + base_model_prefix = "ernie" + _no_split_modules = ["Ernie4_5_DecoderLayer"] + # _keep_in_fp32_modules = ["mlp.gate", "e_score_correction_bias"] + + +class Ernie4_5_Model(Ernie4_5_PretrainedModel): + """The core ERNIE transformer model with MoE (Mixture of Experts) support.""" + + def __init__(self, config: 
Ernie4_5_MoEConfig): + """Initialize the ERNIE model architecture. + + Args: + config (Ernie4_5_MoEConfig): Model configuration. + """ + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.config = config + + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + self.layers = nn.ModuleList( + [Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + Norm = RMSNorm + self.norm = Norm(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + """Get the input embedding layer. + + Returns: + nn.Embedding: The embedding layer for input tokens + """ + return self.embed_tokens + + def set_input_embeddings(self, value): + """Set new input embeddings. + + Args: + value (nn.Embedding): New embedding layer to use + """ + self.embed_tokens = value + + def forward( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + attention_mask=None, + attn_mask_start_row_indices=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + ): + """Forward pass through the ERNIE model. + + Args: + input_ids (Optional[torch.Tensor]): Input token IDs + position_ids (Optional[torch.Tensor]): Position indices + attention_mask (Optional[torch.Tensor]): Attention mask + attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices + inputs_embeds (Optional[torch.Tensor]): Precomputed embeddings + use_cache (Optional[bool]): Whether to cache key/value states + past_key_values (Optional[Tuple[Tuple[torch.Tensor]]]): Cached key/value states + output_attentions (Optional[bool]): Whether to output attention weights + output_hidden_states (Optional[bool]): Whether to output all hidden states + return_dict (Optional[bool]): Whether to return dict or tuple + + Returns: + Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + Various outputs depending on configuration, including: + - last_hidden_state: Final layer hidden states + - past_key_values: Cached key/value states if use_cache=True + - hidden_states: All hidden states if output_hidden_states=True + - attentions: Attention weights if output_attentions=True + - router_loss: MoE router loss if use_moe=True + - gate_logits: MoE gate logits if use_moe=True + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + _, seq_length = input_ids.shape + elif inputs_embeds is not None: + _, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] + 
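+            # For example, with a 16-token prompt already cached and one new
+            # token per decode step, cache_length is 16 and
+            # seq_length_with_past becomes 17 for the current step.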
seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + if getattr(self.config, "use_moe", False): + all_router_loss = torch.tensor(0.0).to(device=inputs_embeds.device) + else: + all_router_loss = None + all_gate_logits = () + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + attn_mask_start_row_indices, + position_ids, + token_type_ids, + output_attentions, + past_key_value, + use_cache, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + if self.config.use_moe: + layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1] + all_gate_logits = all_gate_logits + (gate_logits,) + + if past_key_value is not None: + hidden_states = hidden_states[:, -1:, :] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_router_loss, + all_gate_logits, + ] + if v is not None + ) + + # assert all_router_loss is None, f'moe not support `return-dict`' + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + router_loss=all_router_loss, + gate_logits=all_gate_logits, + ) + + +def parallel_matmul( + x, + y, + bias=None, + transpose_y=False, +): + """ + Performs parallel matrix multiplication with tensor model parallelism support. 
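+    Note: in this implementation the product is a plain dense matmul (with an
+    optional transpose of `y` and a bias add); no cross-device communication is
+    performed.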
+ + Args: + x (torch.Tensor): Input tensor with shape [batch_size, seq_len, hidden_size] + y (Union[torch.Tensor, EagerParamBase]): Weight matrix which can be: + - Regular tensor + - Distributed parameter in tensor parallel mode + bias (Optional[torch.Tensor]): Optional bias tensor + transpose_y (bool): Whether to transpose the 'y' matrix before multiplication + # tensor_parallel_degree (int): Degree of tensor model parallelism (default: 1) + # tensor_parallel_output (bool): Whether to keep output in tensor parallel format + or gather across devices (default: True) + fuse_linear (bool): Whether to use fused linear operation for optimization + + Returns: + torch.Tensor + + Raises: + AssertionError: If tensor parallel is enabled but weight is not distributed + AttributeError: If called without distributed.launch context + """ + if transpose_y: + logits = torch.matmul(x, y.T) + else: + logits = torch.matmul(x, y) + if bias is not None: + logits += bias + return logits + + +def calc_lm_head_logits( + config, hidden_states, weight, bias, tensor_parallel_output=None, training=True +): + """ + Calculate language model head logits with support for various parallelization strategies. + + This is the core function that computes the final output logits for a language model, + handling sequence parallelism and tensor parallelism configurations. + + Args: + config (Ernie4_5_Config): Model configuration. + hidden_states (Tensor): Hidden states from the transformer layers + weight (Tensor): Weight matrix for the language model head + bias (Tensor): Bias vector for the language model head + tensor_parallel_output (bool, optional): Override for tensor parallel output behavior. + If None, uses config.tensor_parallel_output. + Defaults to None. + training (bool, optional): Whether in training mode. Defaults to True. + + Returns: + Tensor: The computed logits for language modeling. + """ + if tensor_parallel_output is None: + tensor_parallel_output = config.tensor_parallel_output + logits = parallel_matmul( + hidden_states, + weight, + bias=bias, + transpose_y=config.tie_word_embeddings, + ) + + return logits + + +def calc_multimodal_logits( + last_hidden_state: torch.Tensor, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor, + mm_head_weight: torch.Tensor, + mm_head_bias: torch.Tensor, + token_type_ids_shifted: torch.Tensor, + config: Ernie4_5_VLMoEConfig, +): + """ + calculate logits for pure text, multimodal text, and image + Args: + last_hidden_state: The hidden of the last layer, in sequence-parallel, is in the split state. + ... + token_type_ids_shifted: # Non-sp split tensor + The token-type-ids at the label position is used to select the lm-head corresponding to each token. + Note: In the id sequence of alternating images and texts, the last text token will predict the image id, + and vice versa, so it is necessary to select the lmhead weight corresponding to the label type. + """ + # Align the type of ids with the type of label. For the last ids, assume that the token type remains unchanged. 
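+    # For example, for ids [t0, t1, i2, i3] the label of t1 is the image token
+    # i2, so position 1 must be scored with the image head; the left-shifted
+    # token types give exactly that per-position head selection.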
+ # TODO: Pass token-type-ids from reader + assert last_hidden_state.shape[:2] == token_type_ids_shifted.shape, ( + last_hidden_state.shape, + token_type_ids_shifted.shape, + ) + parallel_matmul_tp = partial( + parallel_matmul, + ) + + if mm_head_weight is None: + if config.use_recompute_loss_fn: + return last_hidden_state, None, None + score_text = parallel_matmul_tp(last_hidden_state, lm_head_weight, lm_head_bias) + return score_text, None, None + + image_mask_shifted = token_type_ids_shifted == TokenType.image + text_pos_shifted = token_type_ids_shifted == TokenType.text + + if text_pos_shifted.any().item() > 0: + score_text = parallel_matmul_tp( + last_hidden_state[text_pos_shifted], lm_head_weight, lm_head_bias + ) + else: + score_text = None + + if mm_head_weight is not None and image_mask_shifted.any().item() > 0: + score_image = parallel_matmul_tp( + last_hidden_state[image_mask_shifted], mm_head_weight, mm_head_bias + ) + else: + score_image = None + + return score_text, score_image, None + + +class Ernie4_5_MoeLMHead(nn.Module): + """Language model head for ERNIE with support for tensor parallelism.""" + + def __init__(self, config): + """Initialize the language model head. + + Args: + config (Ernie4_5_Config): Model configuration containing: + - vocab_size: Size of vocabulary + - hidden_size: Dimension of hidden states + # - tensor_parallel_degree: Degree of tensor parallelism + - tie_word_embeddings: Whether to tie input/output embeddings + - weight_share_add_bias: Whether to add bias when weight sharing + - use_bias: Whether to use bias term + - use_recompute_loss_fn: Whether to defer logits computation to loss function + - use_sparse_head_and_loss_fn: Whether to use sparse head computation + """ + + super(Ernie4_5_MoeLMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + if config.tie_word_embeddings: + self.weight = nn.Parameter( + torch.empty( + vocab_size, config.hidden_size, dtype=torch.get_default_dtype() + ) + ) + else: + self.weight = nn.Parameter( + torch.empty( + config.hidden_size, vocab_size, dtype=torch.get_default_dtype() + ) + ) + nn.init.xavier_uniform_(self.weight) + + logger.info( + f"output-weight:{self.weight.shape} tie_word_embeddings:{config.tie_word_embeddings}" + ) + + if config.weight_share_add_bias and config.use_bias: + self.bias = nn.Parameter( + torch.zeros(vocab_size, dtype=torch.get_default_dtype()) + ) + else: + self.bias = None + + # Must set distributed attr for Tensor Parallel ! + self.weight.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) + if config.weight_share_add_bias and config.use_bias: + self.bias.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) + + if self.weight.is_distributed: + self.weight.split_axis = 1 + if ( + config.weight_share_add_bias + and config.use_bias + and self.bias.is_distributed + ): + self.bias.split_axis = 0 + + if self.config.use_recompute_loss_fn: + logger.info( + "Using recompute_loss_fn, the calculation of logits will be moved into " + "loss_fn for memory optimization" + ) + + def forward(self, hidden_states, tensor_parallel_output=None): + """Project hidden states to vocabulary logits. + + Args: + hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + tensor_parallel_output (Optional[bool]): Whether to output parallel results. Defaults to None. 
+ + Returns: + Union[ + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # When use_recompute_loss_fn or use_sparse_head_and_loss_fn + - hidden_states: Original input + - weight: Projection weights + - bias: Optional bias term + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]: # With tensor_parallel_output + Same as above plus tensor_parallel_output flag + torch.Tensor: # Normal case + Logits tensor of shape [batch_size, seq_len, vocab_size] + ] + """ + return calc_lm_head_logits( + self.config, + hidden_states, + self.weight, + self.bias, + tensor_parallel_output, + training=self.training, + ) + + +class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin): + """ERNIE Mixture of Experts (MoE) model for causal language modeling.""" + + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + """ + Initializes the ERNIE MoE model for causal language modeling. + + Args: + config (dict): Model configuration. + """ + super().__init__(config) + + # initialize-trick for big model, + # see https://github.com/bigscience-workshop/bigscience/blob/master/train/tr11-176B-ml/README.md#std-init + new_initializer_range = math.sqrt(0.3333 / config.hidden_size) + logger.info( + f"change initializer-range from {config.initializer_range} to {new_initializer_range}" + ) + config.initializer_range = new_initializer_range + self.config = config + self.model = Ernie4_5_Model(config) + self.lm_head = Ernie4_5_MoeLMHead(config) + + self.tie_weights() # maybe weight share + + def get_input_embeddings(self): + """Returns the input embeddings layer.""" + return self.model.embed_tokens + + def set_input_embeddings(self, value): + """Sets the input embeddings layer.""" + self.model.embed_tokens = value + + def get_output_embeddings(self): + """Returns the output embeddings (LM head).""" + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + """Sets the output embeddings layer.""" + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + """Sets the ERNIE decoder model.""" + self.model = decoder + + def get_decoder(self): + """Get the transformer decoder. + + Returns: + nn.Layer: The decoder module + """ + return self.model + + def prepare_attention_mask_for_generation( + self, input_ids, pad_token_id, eos_token_id + ): + """Avoid using attention_mask with flash_attn on generation.""" + if self.config.use_flash_attention: + return None + return super().prepare_attention_mask_for_generation( + input_ids, pad_token_id, eos_token_id + ) + + +class VisionMlp(nn.Module): + """VisionMLP""" + + def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.act = ACT2FN[hidden_act] + self.fc2 = nn.Linear(hidden_dim, dim) + + def forward(self, x) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor + + Returns: + torch.Tensor: VisionMLP output tensor + """ + return self.fc2(self.act(self.fc1(x))) + + +class PatchEmbed(nn.Module): + """PatchEmbed""" + + def __init__( + self, + patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + """ + Args: + patch_size (int, optional): patch size. Defaults to 14. + in_channels (int, optional): number of channels. Defaults to 3. + embed_dim (int, optional): embedding dimension. Defaults to 1152. 
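+
+        Example (illustrative): with the defaults, each flattened patch carries
+            3 * 14 * 14 = 588 values, which the bias-free linear projection
+            maps to an embedding of size 1152.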
+ """ + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + self.proj = nn.Linear( + in_channels * patch_size * patch_size, embed_dim, bias=False + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): hidden states + + Returns: + torch.Tensor: output tensor + """ + target_dtype = self.proj.weight.dtype + + hidden_states = self.proj(hidden_states.to(target_dtype)) + + return hidden_states + + +class VisionRotaryEmbedding(nn.Module): + """VisionRotaryEmbedding""" + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + """ + Args: + dim (int): the dimension of each token. + theta (float, optional): the frequency factor. Defaults to 10000.0. + """ + super().__init__() + self.inv_freq = 1.0 / theta ** ( + torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim + ) + + def forward(self, seqlen: int) -> torch.Tensor: + """ + Args: + seqlen (int): length of sequence. + + Returns: + torch.Tensor: rotary position embedding + """ + seq = torch.arange(seqlen).to(self.inv_freq.dtype) + freqs = torch.outer(input=seq, vec2=self.inv_freq) + return freqs + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) # shape is the same as x + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + """Applies Rotary Position Embedding to the input tensors. + + Args: + tensor (torch.Tensor): The input tensor. + freqs (torch.Tensor): The frequencies used for the rotation. + Returns: + output (torch.Tensor): the tensor rotated using the Rotary Position Embedding. 
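+
+    Note: the rotation is computed in float32 and the result is cast back to
+        the input tensor's original dtype.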
+ """ + orig_dtype = tensor.dtype + + tensor = tensor.type(dtype=torch.float32) + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).tile(1, 1, 2).unsqueeze(0).type(dtype=torch.float32) + sin = sin.unsqueeze(1).tile(1, 1, 2).unsqueeze(0).type(dtype=torch.float32) + output = tensor * cos + rotate_half(tensor) * sin + output = output.to(orig_dtype) + return output + + +class VisionAttention(nn.Module): + """VisionAttention""" + + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + self.head_dim = dim // num_heads # must added + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """forward function for vision attention""" + seq_length = hidden_states.shape[0] + qkv = ( + self.qkv(hidden_states) + .reshape([seq_length, 3, self.num_heads, -1]) + .permute(1, 0, 2, 3) + ) + q, k, v = qkv.unbind(axis=0) + + q = apply_rotary_pos_emb_vision(q.unsqueeze(dim=0), rotary_pos_emb).squeeze( + dim=0 + ) + k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze( + dim=0 + ) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + attention_mask = torch.full( + [1, seq_length, seq_length], + torch.finfo(q.dtype).min, + device=q.device, + dtype=q.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class DFNRopeVisionBlock(nn.Module): + """DFNRopeVisionBlock""" + + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + """ + Args: + config (dict): model configuration. + attn_implementation (str, optional): attention implementation. Defaults to "sdpa". 
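+
+        Note: the block uses a pre-norm layout, i.e. LayerNorm -> attention ->
+            residual add, then LayerNorm -> MLP -> residual add.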
+ """ + super().__init__() + self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6) + self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6) + mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) + + self.attn = VisionAttention(config.embed_dim, num_heads=config.num_heads) + self.mlp = VisionMlp( + dim=config.embed_dim, + hidden_dim=mlp_hidden_dim, + hidden_act=config.hidden_act, + ) + self.config = config + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + """ + Args: + hidden_states(torch.Tensor): hidden states + cu_seqlens (torch.Tensor): cumulative sequence lengths + rotary_pos_emb: rotary position embedding + + Returns: + torch.Tensor: output tensor + """ + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class DFNRopeVisionTransformerPreTrainedModel(PreTrainedModel): + """DFNRopeVisionTransformerPreTrainedModel""" + + config_class = DFNRopeVisionTransformerConfig + _tp_plan = {} + + def __init__(self, config) -> None: + """ + Args: + config (dict): model configuration + """ + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [DFNRopeVisionBlock(config) for _ in range(config.depth)] + ) + + assert ( + config.hidden_size == config.embed_dim + ), "in DFNRope, vit's config.hidden must be equal to config.embed_dim" + self.ln = nn.LayerNorm(config.hidden_size, eps=1e-6) + + def rot_pos_emb(self, grid_thw, num_pad=0): + """rot_pos_emb + + Args: + grid_thw (torch.Tensor): grid thw of input + + Returns: + torch.Tensor: rotary position embedding + """ + pos_ids = [] + grid_hw_array = np.array(grid_thw.cpu(), dtype=np.int64) + for t, h, w in grid_hw_array: + hpos_ids = np.arange(h).reshape([-1, 1]) + hpos_ids = np.tile(hpos_ids, (1, w)) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = np.transpose(hpos_ids, (0, 2, 1, 3)) + hpos_ids = hpos_ids.flatten() + + wpos_ids = np.arange(w).reshape([1, -1]) + wpos_ids = np.tile(wpos_ids, (h, 1)) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = np.transpose(wpos_ids, (0, 2, 1, 3)) + wpos_ids = wpos_ids.flatten() + + stacked_ids = np.stack([hpos_ids, wpos_ids], axis=-1) + tiled_ids = np.tile(stacked_ids, (t, 1)) + pos_ids.append(tiled_ids) + + pos_ids = np.concatenate(pos_ids, axis=0) + if num_pad > 0: + pos_ids = np.concatenate( + [pos_ids, np.zeros((num_pad, 2), dtype=pos_ids.dtype)] + ) + max_grid_size = np.amax(grid_hw_array[:, 1:]) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_dim=1) + return rotary_pos_emb + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 + ) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): input tensor + grid_thw (torch.Tensor): grid thw of input + num_pad (int): number of padding tokens + + Returns: + torch.Tensor: output tensor + """ + hidden_states = 
self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw, num_pad=num_pad) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + + if num_pad > 0: + cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) + cu_seqlens[-1] = cu_seqlens[-2] + num_pad + else: + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for idx, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + + ret = self.ln(hidden_states) # add norm + return ret + + +class VariableResolutionResamplerModel(nn.Module): + """ + VariableResolutionResamplerModel, support variable resolution + """ + + def __init__(self, in_dim, out_dim, spatial_conv_size, temporal_conv_size, config): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.config = config + self.spatial_conv_size = spatial_conv_size + self.temporal_conv_size = temporal_conv_size + self.use_temporal_conv = config.use_temporal_conv + + # compress 2d conv(picture) to 1d + self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size + # compress 3d conv(video) to 1d + self.temporal_dim = ( + self.in_dim + * self.spatial_conv_size + * self.spatial_conv_size + * self.temporal_conv_size + ) + + # using unique name space start with "mm_resampler_" + with UniqueNameGuard("mm_resampler_") as guard: + + self.spatial_linear = nn.Sequential( + nn.Linear(self.spatial_dim, self.spatial_dim), + nn.GELU(), + nn.Linear(self.spatial_dim, self.spatial_dim), + nn.LayerNorm(self.spatial_dim, eps=1e-6), + ) + + if self.use_temporal_conv: + self.temporal_linear = nn.Sequential( + nn.Linear(self.temporal_dim, self.spatial_dim), + nn.GELU(), + nn.Linear(self.spatial_dim, self.spatial_dim), + nn.LayerNorm(self.spatial_dim, eps=1e-6), + ) + + self.mlp = nn.Linear(self.spatial_dim, self.out_dim) + + out_config = deepcopy(config) + out_config.hidden_size = out_dim + self.after_norm = RMSNorm(out_config) + + def spatial_conv_reshape(self, x, spatial_conv_size): + """ + reshape before linear to imitation conv + """ + S, C = x.shape + x = x.reshape([-1, C * (spatial_conv_size**2)]) + return x + + def forward(self, x, image_mask, token_type_ids, image_type_ids, grid_thw): + """ + x: image_features + image_mask: [B] + token_types_ids: [B] + image_type_ids: [B_image] + grid_thw: [B_image, 3] + """ + assert image_type_ids is not None + + def fwd_spatial(x): + """ + x in the shape of [S, H] + S is ordered in the following way: [ [patch_h*patch_w (row-major traversal)] * patch_time] + H is simply hidden + """ + x = self.spatial_conv_reshape(x, self.spatial_conv_size) + + x = self.spatial_linear(x) + + return x + + def fwd_placeholder(x, grid_thw, to_tensor=False): + """ + x: [S, H] + grid_thw: [S, 3] + the second dimension: [t, h, w] + """ + + grid_thw_cpu = grid_thw.cpu().numpy() + grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size**2) + + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (self.spatial_conv_size**2) + batch_offset = np.empty( + tokens_per_img_or_vid.size, dtype=tokens_per_img_or_vid.dtype + ) + batch_offset[0] = 0 + batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] + + assert ( + self.temporal_conv_size == 2 + ), f"Hard Code: temporal_conv_size==2, got:{self.temporal_conv_size}" + + # TODO: support any temporal conv size + slice_offsets = [] + 
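+            # Illustrative example, assuming spatial_size = 100 tokens per
+            # frame and grid_t = 4 frames: the first pass below gathers the
+            # rows of frames 0 and 2, the second pass gathers frames 1 and 3,
+            # and concatenating the two along the channel dim pairs every frame
+            # with its temporal neighbour (temporal_conv_size == 2).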
for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset + ): + for temp_offset in range(0, temporoal_size, 2): + slice_offsets.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + ) + ) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, axis=-1)).to( + x.device + ) + + slice_offsets2 = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset + ): + for temp_offset in range( + 1 if temporoal_size > 1 else 0, temporoal_size, 2 + ): + slice_offsets2.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + ) + ) + slice_offsets2 = torch.tensor(np.concatenate(slice_offsets2, axis=-1)).to( + x.device + ) + + x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) + x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) + x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) + return x + + def fwd_temporal(x): + x = self.temporal_linear(x) + return x + + def fwd_mlp(x): + x = self.mlp(x) + x = self.after_norm(x) + return x + + x = fwd_spatial(x) + if self.use_temporal_conv: + x = fwd_placeholder(x, grid_thw) + x = fwd_temporal(x) + x = fwd_mlp(x) + return x + + +class Ernie4_5_MoeVLHead(Ernie4_5_MoeLMHead): + """Ernie4_5_MoeVLHead""" + + def __init__(self, config): + super().__init__(config) + self.config = config + if config.mm_vocab_size > 0: + mm_vocab_config = deepcopy(config) + mm_vocab_config.vocab_size = config.mm_vocab_size + assert mm_vocab_config.vocab_size > 0, mm_vocab_config + assert ( + mm_vocab_config.im_patch_id >= mm_vocab_config.max_text_id + ), mm_vocab_config + self.mm_head = Ernie4_5_MoeLMHead(mm_vocab_config) + else: + self.mm_head = None + + def forward(self, hidden_state, token_type_ids_labels, use_cache=False): + """ + Args: + hidden_state(torch.Tensor): hidden state + token_type_ids_labels(torch.Tensor): token ids + use_cache(bool): whether to use cache, default is False + + Returns: + logits_text(torch.Tensor): text logits + logits_image(torch.Tensor): image logits + """ + if not use_cache: + mm_head_weight = self.mm_head.weight if self.mm_head is not None else None + mm_head_bias = self.mm_head.bias if self.mm_head is not None else None + logits_text, logits_image, _ = calc_multimodal_logits( + hidden_state, + self.weight, + self.bias, + mm_head_weight, + mm_head_bias, + token_type_ids_labels, + self.config, + ) + return logits_text, logits_image, None + else: + # TODO,support lm_head decode only + return ( + parallel_matmul( + hidden_state[:, -1:, :], + self.weight, + self.bias, + transpose_y=self.config.tie_word_embeddings, + ), + None, + None, + ) + + +class Ernie4_5_VLMoeForConditionalGeneration(Ernie4_5_MoeForCausalLM): + """Ernie4_5_VLMoeForConditionalGeneration""" + + config_class = Ernie4_5_VLMoEConfig + main_input_name = "pixel_values" + _keep_in_fp16_modules = ["vision_model"] + _tp_plan = {} + + def __init__( + self, config: Ernie4_5_VLMoEConfig, vision_model=None, resampler_model=None + ): + """ + initialize Ernie4_5_VLMoeForConditionalGeneration + + Args: + config(Ernie4_5_VLMoEConfig): Model configuration. 
+ vision_model(nn.Module): vision model + resampler_model(nn.Module): resampler model + """ + super().__init__(config) + + self.vision_model = DFNRopeVisionTransformerPreTrainedModel( + config.vision_config + ) + + self.model.resampler_model = VariableResolutionResamplerModel( + config.pixel_hidden_size, + config.hidden_size, + config.spatial_conv_size, + config.temporal_conv_size, + config=config, + ) + + self.image_preprocess = None + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def add_image_preprocess(self, processor): + """add image preprocess""" + logger.info("image preprocess is set") + + image_preprocess = processor.image_processor + image_preprocess.image_mean_tensor = torch.tensor( + image_preprocess.image_mean, dtype=torch.float32 + ).reshape([1, 3, 1, 1]) + image_preprocess.image_std_tensor = torch.tensor( + image_preprocess.image_std, dtype=torch.float32 + ).reshape([1, 3, 1, 1]) + image_preprocess.rescale_factor = torch.tensor( + image_preprocess.rescale_factor, dtype=torch.float32 + ) + image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze( + [-2, -1] + ).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1) + image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze( + [-2, -1] + ).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1) + + self.image_preprocess = image_preprocess + + def vision_forward( + self, + images, + image_position_ids, + image_attention_mask, + grid_thw, + ): + """vision_forward""" + if self.image_preprocess is not None: + assert images.dtype == torch.uint8, images.dtype + current_device = images.device + self.image_preprocess.image_mean_tensor = ( + self.image_preprocess.image_mean_tensor.to(current_device) + ) + self.image_preprocess.image_std_tensor = ( + self.image_preprocess.image_std_tensor.to(current_device) + ) + images = self.image_preprocess.rescale_factor * images.to(torch.float32) + images = ( + images - self.image_preprocess.image_mean_tensor + ) / self.image_preprocess.image_std_tensor + images = images.to(torch.bfloat16) + else: + assert images.dtype == torch.bfloat16, images.dtype + # logger.info(f"extract feature input - {images}--{grid_thw}") + if grid_thw is not None: + grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) + grid_thw = F.pad( + torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), + [1, 0, 0, 0], + value=1, + ) + image_features = self.vision_model(images, grid_thw) + return image_features + + def vision_mapping_forward( + self, + token_type_ids, + token_type_ids_w_video, + input_ids, + mm_input_ids, + image_features, + inputs_embeds, + image_type_ids, + grid_thw, + ): + """vision_mapping_forward""" + image_mask = input_ids == self.config.im_patch_id + image_features = self.model.resampler_model( + image_features, + image_mask, + token_type_ids_w_video, + image_type_ids, + grid_thw, + ) + + if image_features.dim == 2: + B, N, C = image_features.shape + image_features = image_features.reshape([B * N, C]).to(inputs_embeds.dtype) + # Will overwrite the part of `ids==im_patch_id` in `mm_ids_features` + inputs_embeds[image_mask.to(inputs_embeds.device)] = image_features.to( + inputs_embeds.device + ) + return inputs_embeds + + def prepare_inputs_for_generation( + self, + input_ids, + images=None, + use_cache=False, + past_key_values=None, + inputs_embeds=None, + image_position_ids=None, + image_attention_mask=None, + token_type_ids=None, + image_type_ids=None, + grid_thw=None, + **kwargs, + 
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        images=None,
+        use_cache=False,
+        past_key_values=None,
+        inputs_embeds=None,
+        image_position_ids=None,
+        image_attention_mask=None,
+        token_type_ids=None,
+        image_type_ids=None,
+        grid_thw=None,
+        **kwargs,
+    ):
+        """
+        Prepare decoder inputs for generation.
+
+        Args:
+            input_ids (torch.Tensor): Input ids.
+            images (torch.Tensor): Images. Defaults to None.
+            use_cache (bool): Whether to use cache. Defaults to False.
+            past_key_values (list): Past key values. Defaults to None.
+            inputs_embeds (torch.Tensor): Input embeddings. Defaults to None.
+            image_position_ids (torch.Tensor): Image position ids. Defaults to None.
+            image_attention_mask (torch.Tensor): Image attention mask. Defaults to None.
+            token_type_ids (torch.Tensor): Token type ids. Defaults to None.
+            image_type_ids (torch.Tensor): Image type ids. Defaults to None.
+            grid_thw (torch.Tensor): Grid thw. Defaults to None.
+        """
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+            token_type_ids = token_type_ids[:, -1:]
+            image_type_ids = (
+                image_type_ids[:, -1:] if image_type_ids is not None else None
+            )
+
+        attention_mask = kwargs.get("attention_mask", None)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": True,
+                "attention_mask": attention_mask,
+                "images": images,
+                "image_position_ids": image_position_ids,
+                "image_attention_mask": image_attention_mask,
+                "image_type_ids": image_type_ids,
+                # token_type_ids stays one step longer than input_ids; each newly
+                # generated token is a text token, so append a zero column.
+                "token_type_ids": torch.cat(
+                    [
+                        token_type_ids,
+                        torch.zeros(
+                            [len(token_type_ids), 1], dtype=token_type_ids.dtype
+                        ).to(token_type_ids.device),
+                    ],
+                    dim=-1,
+                ),
+                "grid_thw": grid_thw,
+            }
+        )
+        if self.config.rope_3d:
+            model_inputs.update({"position_ids": kwargs["position_ids"]})
+
+        return model_inputs
+
+    def _post_init(self, original_init, *args, **kwargs):
+        """
+        Label all multimodal parameters in the model, only head and Embedding.
+        Experts parameters are already labeled.
+        """
+        super()._post_init(original_init, *args, **kwargs)
+        if self.lm_head.mm_head is not None:
+            self.lm_head.mm_head.weight.expert_type = "expert_type_1"
+            if getattr(self.lm_head.mm_head, "bias", None) is not None:
+                self.lm_head.mm_head.bias.expert_type = "expert_type_1"
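+
+    # Editor's note (illustrative): `forward` below expects `token_type_ids` to be
+    # one position longer than `input_ids` (see its assert), so that
+    # `token_type_ids[..., 1:]` gives the type of each position's *label* token.
+    # For example, with input_ids of length 5 whose tokens are
+    #     [text, text, <im_patch>, <im_patch>, text]
+    # a valid token_type_ids row has length 6, e.g. [0, 0, 1, 1, 0, 0], and
+    # `prepare_inputs_for_generation` preserves this invariant during decoding by
+    # appending a 0 (TokenType.text) column for each newly generated token.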
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+        images: Optional[torch.Tensor] = None,
+        ignored_index: Optional[int] = 0,
+        return_dict: Optional[bool] = None,
+        image_position_ids: Optional[torch.Tensor] = None,
+        image_attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        image_type_ids: Optional[torch.Tensor] = None,
+        grid_thw: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        """
+        Forward pass for Ernie4_5_VLMoeForConditionalGeneration.
+
+        Args:
+            input_ids (torch.Tensor): Input ids.
+            position_ids (Optional[torch.Tensor], optional): Position ids. Defaults to None.
+            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
+            past_key_values (Optional[List[torch.Tensor]], optional): Past key values. Defaults to None.
+            use_cache (Optional[bool], optional): Use cache. Defaults to None.
+            output_attentions (Optional[bool], optional): Output attentions. Defaults to None.
+            output_hidden_states (Optional[bool], optional): Output hidden states. Defaults to None.
+            labels (Optional[torch.Tensor], optional): Labels. Defaults to None.
+            images (Optional[torch.Tensor], optional): Images. Defaults to None.
+            ignored_index (Optional[int], optional): Ignored index. Defaults to 0.
+            return_dict (Optional[bool], optional): Return dict. Defaults to None.
+            image_position_ids (Optional[torch.Tensor], optional): Image position ids. Defaults to None.
+            image_attention_mask (Optional[torch.Tensor], optional): Image attention mask. Defaults to None.
+            token_type_ids (Optional[torch.Tensor], optional): Token type ids. Defaults to None.
+            image_type_ids (Optional[torch.Tensor], optional): Image type ids. Defaults to None.
+            grid_thw (Optional[torch.Tensor], optional): Grid thw. Defaults to None.
+        """
+        if grid_thw is not None:
+            grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3])
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        image_mask = input_ids == self.config.im_patch_id
+
+        image_rate = image_mask.to(torch.float32).mean()
+
+        if past_key_values is None:
+            if images is not None:
+                assert image_mask.any().item(), (
+                    image_mask.detach().cpu().numpy().tolist(),
+                    input_ids.detach().cpu().numpy().tolist(),
+                    self.config.im_patch_id,
+                    images.shape,
+                )
+                image_features = self.vision_forward(
+                    images,
+                    image_position_ids,
+                    image_attention_mask,
+                    grid_thw,
+                )
+            else:
+                image_features = None  # no more faking
+        else:
+            image_features = None
+        if token_type_ids is None:
+            token_type_ids = image_mask.to(torch.int64)
+            token_type_ids_labels = torch.cat(
+                [token_type_ids[:, 1:], token_type_ids[:, -1:]], 1
+            )
+        else:
+            assert (
+                token_type_ids.shape[1] == input_ids.shape[1] + 1
+            ), f"token_type:{token_type_ids.shape}, ids:{input_ids.shape}"
+            token_type_ids_labels = token_type_ids[..., 1:]
+
+        lm_input_ids = input_ids.clone()
+        mm_input_ids = input_ids.clone()
+
+        inputs_embeds = self.model.embed_tokens(lm_input_ids)
+        token_type_ids_w_video = token_type_ids[..., :-1].clone()
+        token_type_ids[token_type_ids == TokenType.video] = TokenType.image
+
+        if images is not None and image_features is not None:
+            inputs_embeds = self.vision_mapping_forward(
+                token_type_ids,
+                token_type_ids_w_video,
+                input_ids,
+                mm_input_ids,
+                image_features,
+                inputs_embeds,
+                image_type_ids,
+                grid_thw,
+            )
+        else:
+            pass  # do nothing, should not hang under DygraphShardingOptimizerV2
+
+        outputs = self.model(
+            position_ids=position_ids,
+            attention_mask=None,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        if not use_cache:
+            assert outputs.last_hidden_state.shape[:2] == token_type_ids_labels.shape, (
+                outputs.last_hidden_state.shape,
+                token_type_ids_labels.shape,
+            )
+            if self.config.use_recompute_loss_fn:
+                logits = outputs.last_hidden_state
+            else:
+                logits = self.lm_head(outputs.last_hidden_state)
+        else:
+            logits = self.lm_head(outputs.last_hidden_state[:, -1:, :])
+
+        router_loss = outputs.router_loss
+
+        # generation/decoding path: no loss is computed here
+        loss = None
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_loss=outputs.router_loss,
+        )
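+
+    # Editor's note (illustrative): `_resolve_prefix_keys` below matches unprefixed
+    # parameter names against the (possibly prefixed) names found in a loaded state
+    # dict.  With hypothetical inputs
+    #     state_keys_base = {"model.embed_tokens.weight"}
+    #     state_keys_real = {"ernie.model.embed_tokens.weight"}
+    # it returns {"model.embed_tokens.weight": "ernie.model.embed_tokens.weight"},
+    # since the real key ends with the base key; keys containing "mm_embed_tokens"
+    # are matched to each other explicitly.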
+
+    @staticmethod
+    def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
+        """Map base (unprefixed) state-dict keys to their names in the loaded state dict."""
+        # state_keys_map: base -> real
+        state_keys_map = {}
+
+        state_keys_base = set(state_keys_base)
+        state_keys_real = set(state_keys_real)
+
+        for key in state_keys_base:
+            for x in state_keys_real:
+                if "mm_embed_tokens" in x:
+                    if "mm_embed_tokens" in key:
+                        state_keys_map[key] = x
+                        break
+                elif x.endswith(key):
+                    state_keys_map[key] = x
+                    break
+            if key not in state_keys_map:
+                if not ignore_error:
+                    logger.error(f"could not find name {key} in loaded state dict!")
+            else:
+                state_keys_real.remove(state_keys_map[key])
+
+        return state_keys_map
+
+
+@dataclass
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+    """
+    Base class for model outputs with past key values and cross attention layers,
+    with additional support for router components in mixture-of-experts models.
+
+    This extends the base model output to include:
+    1. Router-related outputs for expert selection
+    2. All existing functionality from the parent class
+    """
+
+    last_hidden_state: Optional[Tuple[torch.Tensor]] = None
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None
+    hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None
+    cross_attentions: Optional[Tuple[torch.Tensor]] = None
+    router_loss: Optional[torch.Tensor] = None
+    gate_logits: Optional[Tuple[torch.Tensor]] = None
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed
+            or when `config.output_hidden_states=True`):
+            Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding
+            layer, plus one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or
+            when `config.output_attentions=True`):
+            Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        router_loss (Optional[torch.Tensor]):
+            The routing loss computed by the gating network in mixture-of-experts models. This is typically
+            the load balancing loss that encourages equal expert utilization. None when not using
+            mixture-of-experts routing.
+    """
+
+    loss: Optional[torch.Tensor] = None
+    logits: torch.Tensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None
+    hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None
+    router_loss: Optional[Tuple[torch.Tensor]] = None
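+
+
+# --- Editor's illustrative usage sketch (defined only, never called) ------------
+# A minimal, hypothetical example of wiring the pieces in this file together for
+# generation.  The checkpoint id, the processor class and its call signature are
+# assumptions and are not defined in this module; consult the released model
+# card / processing code for the exact preprocessing inputs (pixel values,
+# `grid_thw`, `token_type_ids`, ...).
+def _example_generate(prompt="Describe the image.", image=None):
+    from transformers import AutoProcessor  # assumed to resolve the VL processor
+
+    name = "baidu/ERNIE-4.5-VL"  # placeholder checkpoint id
+    model = Ernie4_5_VLMoeForConditionalGeneration.from_pretrained(
+        name, torch_dtype=torch.bfloat16
+    )
+    processor = AutoProcessor.from_pretrained(name)
+    model.add_image_preprocess(processor)  # enables the uint8 path in vision_forward
+    inputs = processor(text=prompt, images=image, return_tensors="pt")  # assumed API
+    return model.generate(**inputs, max_new_tokens=64)
+# ---------------------------------------------------------------------------------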