Upload ModularStarEncoder
Browse files- modularStarEncoder.py +2 -2
modularStarEncoder.py
CHANGED
|
@@ -154,8 +154,8 @@ class StarEncoder2LMPredictionHead(nn.Module):
|
|
| 154 |
super().__init__()
|
| 155 |
for element in dir(config):
|
| 156 |
value = getattr(config, element) # Get the attribute value
|
| 157 |
-
if isinstance(value, tuple) or isinstance(value, list):
|
| 158 |
-
setattr(config, element, value[
|
| 159 |
self.transform = StarEncoder2PredictionHeadTransform(config)
|
| 160 |
|
| 161 |
# The output weights are the same as the input embeddings, but there is
|
|
|
|
| 154 |
super().__init__()
|
| 155 |
for element in dir(config):
|
| 156 |
value = getattr(config, element) # Get the attribute value
|
| 157 |
+
if (isinstance(value, tuple) or isinstance(value, list)) and len(value)>0:
|
| 158 |
+
setattr(config, element, value[0])
|
| 159 |
self.transform = StarEncoder2PredictionHeadTransform(config)
|
| 160 |
|
| 161 |
# The output weights are the same as the input embeddings, but there is
|