feat: made from_bert work
modeling_lora.py CHANGED (+11 -5)
@@ -174,18 +174,24 @@ class LoRAParametrization(nn.Module):
 
 
 class BertLoRA(BertPreTrainedModel):
-    def __init__(self, config: JinaBertConfig, add_pooling_layer=True, num_adaptions=1):
+    def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True, num_adaptions=1):
         super().__init__(config)
-        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
+        if bert is None:
+            self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
+        else:
+            self.bert = bert
         self._register_lora(num_adaptions)
         for name, param in super().named_parameters():
             if "lora" not in name:
                 param.requires_grad_(False)
         self.select_task(0)
 
-
-
-
+    @classmethod
+    def from_bert(cls, *args, num_adaptions=1, **kwargs):
+        bert = BertModel.from_pretrained(*args, **kwargs)
+        config = JinaBertConfig.from_pretrained(*args, **kwargs)
+        return cls(config, bert=bert, num_adaptions=num_adaptions)
+
 
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
         self.apply(
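
For context, a minimal usage sketch of the new classmethod. This is hypothetical: the checkpoint name "bert-base-uncased" is only a placeholder, and any extra arguments are forwarded unchanged to both from_pretrained calls.

    # Hypothetical usage: load pretrained BERT weights, then wrap them with LoRA.
    # "bert-base-uncased" is a placeholder checkpoint name.
    model = BertLoRA.from_bert("bert-base-uncased", num_adaptions=2)

    # __init__ freezes every parameter without "lora" in its name,
    # so only the LoRA parameters remain trainable.
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    assert all("lora" in n for n in trainable)

Accepting an already-loaded bert in __init__ (instead of always building BertModel from config) is what lets from_bert reuse the pretrained weights rather than constructing a randomly initialized model and discarding it.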