Update modeling_hf_nomic_bert.py

modeling_hf_nomic_bert.py  CHANGED  +22 -6
@@ -16,7 +16,7 @@ from einops import rearrange, repeat
 from transformers import GPT2Config, PreTrainedModel
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
-
+    MaskedLMOutput,
     SequenceClassifierOutput
 )
@@ -321,7 +321,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
         num_labels = kwargs.pop("num_labels", None)
         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
-
+        if rotary_scaling_factor:
+            config.rotary_scaling_factor = rotary_scaling_factor
+        else:
+            config.rotary_scaling_factor = None
         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
             config.n_positions = 2048
         if num_labels:
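With this hunk, rotary_scaling_factor can be passed straight through from_pretrained and is written onto the config before the model is constructed. A minimal usage sketch, not taken from this commit: the checkpoint id is illustrative, and trust_remote_code=True is assumed because this is a remote-code modeling file.

    from transformers import AutoModel

    # The kwarg is popped inside NomicBertPreTrainedModel.from_pretrained and
    # stored as config.rotary_scaling_factor (None when not given).
    model = AutoModel.from_pretrained(
        "nomic-ai/nomic-bert-2048",   # illustrative checkpoint id
        trust_remote_code=True,
        rotary_scaling_factor=2.0,
    )
    print(model.config.rotary_scaling_factor)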
@@ -330,7 +333,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if "add_pooling_layer" in kwargs:
             model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
         else:
-
+            if cls == NomicBertModel:
+                model = cls(config, *inputs, add_pooling_layer=False)
+            else:
+                model = cls(config, *inputs)
             # TODO: fix this
             # Assuming we know what we're doing when loading from disk
             # Prob a bad assumption but i'm tired and want to train this asap
@@ -551,6 +557,12 @@ class NomicBertRotaryEmbedding(nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)
 
         self._seq_len_cached = 0
         self._cos_cached = None
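The new scale buffer follows the xPos-style scaling convention used by flash-attn's rotary embedding: one scale factor per rotary frequency pair, later raised to a power proportional to the token position and applied to queries and keys in opposite directions. A self-contained sketch of that convention; this illustrates how such a buffer is typically consumed and is not code from this file.

    import torch

    dim, scale_base, seqlen = 64, 512, 8

    # Same formula as the diff: one scale per frequency pair (dim // 2 values).
    scale = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)

    # xPos convention (assumed here): the exponent grows with distance from the
    # sequence midpoint; queries get scale**power and keys the inverse, so their
    # product decays smoothly with relative distance.
    power = (torch.arange(seqlen, dtype=torch.float32) - seqlen // 2) / scale_base
    q_scale = scale ** power.unsqueeze(-1)     # (seqlen, dim // 2)
    k_scale = scale ** (-power.unsqueeze(-1))
    print(q_scale.shape, k_scale.shape)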
@@ -616,7 +628,9 @@
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if max_seqlen is not None:
+        if seqlen > self._seq_len_cached:
+            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
         elif isinstance(seqlen_offset, int):
             self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
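The added branch refreshes the cos/sin cache whenever the incoming sequence is longer than anything cached so far, instead of relying only on max_seqlen being supplied. For orientation, a minimal sketch of what an _update_cos_sin_cache of this kind typically computes; the names mirror the diff, but the body is a generic reconstruction rather than this file's exact implementation.

    import torch

    def update_cos_sin_cache(inv_freq: torch.Tensor, seqlen: int, device, dtype):
        # Angles are the outer product of positions and inverse frequencies;
        # the resulting cos/sin tables are cached so shorter sequences reuse them.
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq.to(device=device, dtype=torch.float32))
        return freqs.cos().to(dtype), freqs.sin().to(dtype)

    inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
    cos, sin = update_cos_sin_cache(inv_freq, seqlen=128, device="cpu", dtype=torch.float32)
    print(cos.shape)  # torch.Size([128, 32])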
@@ -1133,9 +1147,11 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
             )
             total_loss = masked_lm_loss.float()
 
-        return
+        return MaskedLMOutput(
             loss=total_loss,
-
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=None,
         )
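With the pretraining head now returning a MaskedLMOutput (hence the new import at the top of the file), callers get the standard transformers fields: loss, logits, hidden_states, attentions. A small self-contained sketch of the object's shape, using dummy tensors and an illustrative vocab size:

    import torch
    from transformers.models.bert.modeling_bert import MaskedLMOutput

    out = MaskedLMOutput(
        loss=torch.tensor(2.31),            # total_loss in the diff
        logits=torch.zeros(2, 16, 30528),   # prediction_scores: (batch, seq_len, vocab_size)
        hidden_states=None,                 # populated only when hidden states were requested
        attentions=None,
    )
    print(out.loss.item(), out.logits.shape)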