feat: try to monkey-patch index_first_axis
modeling_bert.py (+8, -4)
@@ -28,12 +28,16 @@ from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     BertForPreTrainingOutput,
 )
-from .patched_padding_bert import index_first_axis
+from .patched_padding_bert import index_first_axis as index_first_axis_monkey_patch
+import flash_attn.bert_padding
+flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
+"""
 from flash_attn.bert_padding import (
     index_first_axis_residual,
     pad_input,
     unpad_input,
 )
+"""
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
@@ -172,14 +176,14 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = flash_attn.bert_padding.unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
+                hidden_states = flash_attn.bert_padding.pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -197,7 +201,7 @@ class BertEncoder(nn.Module):
                 subset_cu_seqlens = F.pad(
                     torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                 )
-                hidden_states_subset, hidden_states = index_first_axis_residual(
+                hidden_states_subset, hidden_states = flash_attn.bert_padding.index_first_axis_residual(
                     hidden_states, subset_idx
                 )
                 # It's ok to set max_seqlen_q to be much larger
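The patch works because a Python function resolves global names through its defining module's namespace at call time: rebinding flash_attn.bert_padding.index_first_axis also changes what the helpers defined in that module (unpad_input, pad_input, index_first_axis_residual) call internally. Names that were already copied out with a from-import keep pointing at the old object, which is presumably why the old from-import block is commented out above and the call sites now go through the module attribute. A minimal sketch of that late-binding behaviour, using a toy stand-in module rather than the real flash_attn API (all names below are illustrative only):

import types

# Toy module standing in for flash_attn.bert_padding.
toy = types.ModuleType("toy_padding")
exec(
    """
def index_first_axis(x, idx):
    return [x[i] for i in idx]

def unpad_input(x, idx):
    # index_first_axis is looked up in this module's globals at call time
    return index_first_axis(x, idx)
""",
    toy.__dict__,
)

def patched_index_first_axis(x, idx):
    print("patched index_first_axis called")
    return [x[i] for i in idx]

# The monkey-patch: rebind the module attribute, as the commit does for
# flash_attn.bert_padding.index_first_axis.
toy.index_first_axis = patched_index_first_axis

print(toy.unpad_input(["a", "b", "c"], [2, 0]))  # patched message, then ['c', 'a']

# Caveat: a name copied out earlier (e.g. `from toy_padding import
# index_first_axis`) would still point at the unpatched function, so the
# patch has to run before any such binding is made.

Strictly, only index_first_axis itself must not be pulled in via an early from-import; unpad_input and pad_input would pick up the patch even when imported by name, since they resolve index_first_axis from their own module globals, so routing every call through flash_attn.bert_padding here reads as a belt-and-braces choice.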