Commit f438eeb
Parent(s): e27b4b2

Add loss function and bool cast

modeling_mpt.py CHANGED (+12 -3)
```diff
@@ -56,7 +56,7 @@ class MPTModel(MPTPreTrainedModel):
             for module in self.modules():
                 if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
                     if config.verbose:
-
+                        warnings.warn(f'Removing bias ({module.bias}) from {module}.')
                     module.register_parameter('bias', None)
         if config.verbose and config.verbose > 2:
             print(self)
```
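For context on the surrounding lines: `register_parameter('bias', None)` removes the bias `Parameter` outright rather than zeroing it, so later forward passes simply skip the bias add. A minimal, standalone sketch of that PyTorch behaviour (not MPT-specific):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)                     # comes with a bias Parameter by default
assert isinstance(layer.bias, nn.Parameter)

# Re-registering the name with None drops the parameter from the module;
# nn.Linear.forward then calls F.linear(x, weight, None), i.e. no bias is added.
layer.register_parameter('bias', None)
assert layer.bias is None
print(layer(torch.ones(1, 4)))              # output of the now bias-free layer
```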
```diff
@@ -131,6 +131,10 @@ class MPTModel(MPTPreTrainedModel):
     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if attention_mask is not None:
+            attention_mask = attention_mask.bool()
+        if prefix_mask is not None:
+            prefix_mask = prefix_mask.bool()
         if not return_dict:
             raise NotImplementedError('return_dict False is not implemented yet for MPT')
         if output_attentions:
```
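The new casts normalise whatever dtype the masks arrive in: Hugging Face tokenizers and `generate` typically pass 0/1 integer masks, while logical operations such as `~mask` only behave as negation on boolean tensors. A small illustration with made-up mask values:

```python
import torch

# Typical tokenizer output: 1 = real token, 0 = padding, dtype int64.
attention_mask = torch.tensor([[1, 1, 1, 0]])
attention_mask = attention_mask.bool()       # tensor([[True, True, True, False]])

# With a boolean mask, logical negation and masked_fill do what you expect.
scores = torch.zeros(1, 4)
scores = scores.masked_fill(~attention_mask, float('-inf'))
print(scores)                                # padding position is -inf before softmax
```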
```diff
@@ -228,7 +232,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
     def get_decoder(self):
         return self.transformer
 
-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
@@ -237,7 +241,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if self.logit_scale == 0:
             warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
         logits *= self.logit_scale
-
+        loss = None
+        if labels is not None:
+            labels = torch.roll(labels, shifts=-1)
+            labels[:, -1] = -100
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
 
     def param_init_fn(self, module):
         init_fn_name = self.config.init_config['name']
```
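The added loss is the standard causal-LM objective: the targets are the inputs shifted left by one position, the last position (which has no next token) is set to -100, the default `ignore_index` of `F.cross_entropy`, and logits and labels are flattened to `(batch * seq_len, vocab)` and `(batch * seq_len,)` for a token-level cross entropy. A self-contained sketch of the same shift-and-mask step, with made-up shapes:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)          # stand-in for the model's logits
labels = torch.randint(0, vocab, (batch, seq_len))   # usually just a copy of input_ids

# Position t is scored against token t+1: shift the labels left by one...
targets = torch.roll(labels, shifts=-1, dims=-1)
# ...and blank out the last position, which has no real next token.
targets[:, -1] = -100                                # -100 = cross_entropy's ignore_index

loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))
print(loss)                                          # scalar mean over non-ignored tokens
```

The diff's version calls `torch.roll` without `dims`, which rolls the flattened tensor; the values that wrap across row boundaries only ever land in the last column, which is immediately overwritten with -100, so the result is the same.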
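With the new `labels` argument in place, a loss (and hence perplexity) can be read straight off the output. A hedged usage sketch, assuming the file ships in a repository loaded with `trust_remote_code=True`; the checkpoint name is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'some-org/some-mpt-checkpoint'   # placeholder for a repo containing this modeling_mpt.py
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)

batch = tokenizer('MPT now returns a loss when labels are supplied.', return_tensors='pt')
with torch.no_grad():
    out = model(input_ids=batch['input_ids'], labels=batch['input_ids'])

print(out.loss)             # next-token cross entropy
print(torch.exp(out.loss))  # perplexity of the sequence
```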