Upload ModularStarEncoder
Browse files- modularStarEncoder.py +6 -3
modularStarEncoder.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
-
from
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union
|
|
@@ -171,8 +171,11 @@ class StarEncoder2PreTrainingHeads(nn.Module):
|
|
| 171 |
|
| 172 |
def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
|
| 173 |
if self.is_matryoshka:
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
| 176 |
else:
|
| 177 |
prediction_scores = self.predictions(sequence_output)
|
| 178 |
seq_relationship_score = self.seq_relationship(pooled_output)
|
|
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
+
from config import ModularStarEncoderConfig
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union
|
|
|
|
| 171 |
|
| 172 |
def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
|
| 173 |
if self.is_matryoshka:
|
| 174 |
+
device_sequence = sequence_output.get_device()
|
| 175 |
+
if device_sequence<0:
|
| 176 |
+
device_sequence = "cpu"
|
| 177 |
+
prediction_scores = self.predictions(torch.cat([sequence_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(sequence_output.size()[0],sequence_output.size()[1],-1)],dim=-1))
|
| 178 |
+
seq_relationship_score = self.seq_relationship(torch.cat([pooled_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(pooled_output.size()[0],-1)],dim=-1))
|
| 179 |
else:
|
| 180 |
prediction_scores = self.predictions(sequence_output)
|
| 181 |
seq_relationship_score = self.seq_relationship(pooled_output)
|