feat: reverted monkey patch
- configuration_bert.py  +0 -2
- modeling_bert.py       +5 -17
    	
configuration_bert.py CHANGED

@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ BERT model configuration"""
-from collections import OrderedDict
-from typing import Mapping

 from transformers import PretrainedConfig

    	
modeling_bert.py CHANGED

@@ -28,16 +28,13 @@ from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     BertForPreTrainingOutput,
 )
-from .patched_padding_bert import index_first_axis as index_first_axis_monkey_patch
-import flash_attn.bert_padding
-flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
-"""
 from flash_attn.bert_padding import (
+    index_first_axis,
     index_first_axis_residual,
     pad_input,
     unpad_input,
 )
-
+
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
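The removed lines above monkey patched flash_attn.bert_padding.index_first_axis with a variant from a local patched_padding_bert module; after the revert, index_first_axis is imported from flash_attn.bert_padding directly. As a minimal, self-contained sketch of the mechanics such a patch relies on (math.sqrt stands in for index_first_axis; nothing here is the repository's code): reassigning a module attribute is seen by call sites that resolve the name through the module at call time, but not by names already bound with "from module import name".

import math

original_sqrt = math.sqrt

def patched_sqrt(x):
    # toy stand-in for the patched index_first_axis
    return original_sqrt(x) + 1.0

math.sqrt = patched_sqrt        # module-level monkey patch, analogous to
                                # flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
print(math.sqrt(4.0))           # 3.0: call sites that resolve the name through the module see the patch

from math import sqrt           # binds whatever math.sqrt is right now, i.e. the patched function
math.sqrt = original_sqrt       # reverting the module attribute...
print(sqrt(4.0))                # 3.0: ...does not rebind names that were imported earlier
print(math.sqrt(4.0))           # 2.0: module-level look-ups see the original again

This call-time lookup is presumably what let the old patch change the behaviour of helpers inside flash_attn.bert_padding (for example unpad_input, which uses index_first_axis internally) without touching their call sites, and why reverting it amounts to importing the library's own index_first_axis again.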
@@ -176,14 +173,14 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch =
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-                hidden_states =
+                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
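The hunk above restores plain unpad_input(...) and pad_input(...) calls on the names imported from flash_attn.bert_padding. A rough, plain-PyTorch sketch of what that unpad, run layers, pad round trip does (shapes, names, and the mask convention are illustrative, not the library's internals):

import torch
import torch.nn.functional as F

batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
key_padding_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]], dtype=torch.bool)  # True = real token

# "unpad": keep only the real tokens, flattened to (total_tokens, hidden)
indices = torch.nonzero(key_padding_mask.flatten(), as_tuple=False).flatten()
unpadded = hidden_states.reshape(-1, hidden)[indices]

# cumulative sequence lengths and max length, the metadata variable-length attention kernels expect
seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seqlens.max())

# ... the encoder layers would run on `unpadded` with cu_seqlens / max_seqlen here ...

# "pad": scatter the processed tokens back into a (batch, seqlen, hidden) tensor
padded = torch.zeros(batch * seqlen, hidden, dtype=unpadded.dtype)
padded[indices] = unpadded
padded = padded.reshape(batch, seqlen, hidden)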
@@ -201,7 +198,7 @@ class BertEncoder(nn.Module):
                     subset_cu_seqlens = F.pad(
                         torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                     )
-                hidden_states_subset, hidden_states =
+                hidden_states_subset, hidden_states = index_first_axis_residual(
                     hidden_states, subset_idx
                 )
                 # It's ok to set max_seqlen_q to be much larger
@@ -425,15 +422,6 @@ class BertModel(BertPreTrainedModel):
             pooler_output=pooled_output,
         )

-    def to(self, *args, **kwargs):
-        print(f'In BERT, calling to({args, kwargs})')
-        result = super().to(*args, **kwargs)
-        if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
-            for layer in result.encoder.layers:
-                layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
-                layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)
-        return result
-

 class BertForPreTraining(BertPreTrainedModel):
     def __init__(self, config: JinaBertConfig):
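The hunk above drops a custom BertModel.to() override that re-cast each layer's alibi_slopes back to float32 after a dtype conversion, presumably because the flash-attention kernels expect ALiBi slopes in fp32. A minimal, hypothetical sketch of that pattern on a toy module (made-up names, not the repository's class):

import torch
import torch.nn as nn

class ToyAttnWithAlibi(nn.Module):
    """Hypothetical stand-in for a module that must keep one buffer in fp32."""

    def __init__(self, num_heads: int = 8):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.register_buffer("alibi_slopes", torch.rand(num_heads, dtype=torch.float32))

    def to(self, *args, **kwargs):
        # Same shape as the override removed above: delegate to nn.Module.to,
        # then undo the dtype cast for the fp32-only buffer.
        result = super().to(*args, **kwargs)
        if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
            result.alibi_slopes = result.alibi_slopes.to(torch.float32)
        return result

m = ToyAttnWithAlibi().to(torch.bfloat16)
assert m.proj.weight.dtype == torch.bfloat16
assert m.alibi_slopes.dtype == torch.float32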
