Upload ModularStarEncoder
modularStarEncoder.py +5 -47
@@ -1,39 +1,21 @@
 from transformers import AutoConfig, Starcoder2Model, Starcoder2Config
 import sys
+from config import ModularStarEncoderConfig
 import os
-from .config import ModularStarEncoderConfig
-import math
-import os
-import warnings
 from dataclasses import dataclass
-from typing import
+from typing import Optional, Tuple, Union
 import sys
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import
-
+from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-    BaseModelOutputWithPoolingAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-    MaskedLMOutput,
-    MultipleChoiceModelOutput,
-    NextSentencePredictorOutput,
-    QuestionAnsweringModelOutput,
-    SequenceClassifierOutput,
-    TokenClassifierOutput,
-)
 from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import (
     ModelOutput,
-
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
+
     logging,
-
+
 )
 
 logger = logging.get_logger(__name__)
@@ -243,11 +225,7 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    # def get_output_embeddings(self):
-    #     return self.cls.predictions.decoder
 
-    # def set_output_embeddings(self, new_embeddings):
-    #     self.cls.predictions.decoder = new_embeddings
 
 
 
@@ -279,40 +257,20 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
         kwargs (`Dict[str, any]`, optional, defaults to *{}*):
             Used to hide legacy arguments that have been deprecated.
 
-        Returns:
 
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, BertForPreTraining
-        >>> import torch
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
-        >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
-
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
-        >>> outputs = model(**inputs)
-
-        >>> prediction_logits = outputs.prediction_logits
-        >>> seq_relationship_logits = outputs.seq_relationship_logits
-        ```
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = self.starEncoder2(
             input_ids,
             attention_mask=attention_mask,
-            # token_type_ids=token_type_ids,
             position_ids=position_ids,
-            # head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=True,
             return_dict=return_dict,
         )
 
-
-        #TODO FIX FOR EFFICIENCY, COMPUTE FORWARD PASS JUST ON MATRYOSKA LAYERS
         #if layer matryoshka on, compute the scores for all the heads
         if self.layer_matryoshka_loss:
             prediction_scores = []
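The one functional change in the import block is `from .config import ModularStarEncoderConfig` becoming `from config import ModularStarEncoderConfig`: uploaded as a standalone file in a model repository, the module has no parent package, so the relative import would fail when transformers executes it as remote code. A minimal loading sketch under that assumption (the repo id below is a placeholder, not taken from this commit):

```python
# Placeholder repo id; substitute the Hub repository that ships
# modularStarEncoder.py and config.py alongside the weights.
from transformers import AutoModel, AutoTokenizer

repo_id = "some-org/ModularStarEncoder"

# trust_remote_code=True tells transformers to download and execute the
# repo's own modularStarEncoder.py, which is why its imports must resolve
# as top-level modules (hence `from config import ...`).
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("def add(a, b): return a + b", return_tensors="pt")
outputs = model(**inputs)
```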
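The forward pass hard-codes `output_hidden_states=True`, and the deleted TODO hints at why: with `layer_matryoshka_loss` enabled, prediction scores are collected from several intermediate layers rather than only the last one (the TODO notes it would be cheaper to run the forward pass just on those layers). The diff cuts off right after `prediction_scores = []`, so the following is only a sketch of the pattern, with hypothetical names (`matryoshka_scores`, `heads`, `layer_indices`) that do not appear in the commit:

```python
import torch
from torch import nn

def matryoshka_scores(hidden_states, heads, layer_indices):
    """Apply one scoring head per selected intermediate layer.

    hidden_states: per-layer tensors, as returned by an encoder called
                   with output_hidden_states=True
    heads:         one scoring head per entry in layer_indices
    layer_indices: which hidden states take part in the matryoshka loss
    """
    return [head(hidden_states[i]) for head, i in zip(heads, layer_indices)]

# Toy usage: five fake "layers" of shape (batch=2, seq=4, hidden=8),
# scored at layers 2 and 4 with independent linear heads.
states = tuple(torch.randn(2, 4, 8) for _ in range(5))
heads = nn.ModuleList(nn.Linear(8, 16) for _ in range(2))
prediction_scores = matryoshka_scores(states, heads, layer_indices=[2, 4])
```

Each selected depth gets its own head, so shallower prefixes of the encoder are trained to produce representations that are usable on their own.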