Kamichanw committed
Commit 0606604 · 1 Parent(s): 9275bf8

fix attn mask & chat template

Files changed (2)
  1. modeling_llada.py +15 -11
  2. tokenizer_config.json +1 -1
modeling_llada.py CHANGED
@@ -654,7 +654,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=False,
         )
@@ -665,6 +665,7 @@ class LLaDABlock(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -712,7 +713,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             is_causal=False,
         )
@@ -785,6 +786,7 @@ class LLaDASequentialBlock(LLaDABlock):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -805,10 +807,10 @@ class LLaDASequentialBlock(LLaDABlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+                self.attention, q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+            att, cache = self.attention(q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache)
 
         # Add attention scores.
         # shape: (B, T, C)
@@ -887,6 +889,7 @@ class LLaDALlamaBlock(LLaDABlock):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -905,10 +908,10 @@ class LLaDALlamaBlock(LLaDABlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+                self.attention, q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+            att, cache = self.attention(q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache)
 
         # Add attention scores.
         # shape: (B, T, C)
@@ -977,6 +980,7 @@ class LLaDABlockGroup(nn.ModuleList):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
@@ -1001,11 +1005,11 @@ class LLaDABlockGroup(nn.ModuleList):
             ):
                 # shape: (batch_size, seq_len, d_model)
                 x, cache = self._activation_checkpoint_fn(  # type: ignore
-                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                    block, x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache
                 )
             else:
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+                x, cache = block(x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache)
             if attn_key_values is not None:
                 assert cache is not None
                 attn_key_values.append(cache)
@@ -1308,11 +1312,11 @@ class LLaDAModel(nn.Module):
             ):
                 # shape: (batch_size, seq_len, d_model)
                 x, cache = self._activation_checkpoint_fn(
-                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                    block, x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache
                 )
             else:
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+                x, cache = block(x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache)
             if attn_key_values is not None:
                 assert cache is not None
                 attn_key_values.append(cache)
@@ -1330,7 +1334,7 @@ class LLaDAModel(nn.Module):
                 ]
             )
             x, cache = block_group(
-                x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
+                x, attention_bias=attention_bias, attention_mask=attention_mask, layers_past=layers_past, use_cache=use_cache
             )
             if attn_key_values is not None:
                 assert cache is not None
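
Taken together, the modeling_llada.py hunks thread a new attention_mask argument from the model forward pass through the block groups and blocks down to F.scaled_dot_product_attention, which previously always received attn_mask=None, so padding tokens were never excluded from attention. The sketch below only illustrates the effect of forwarding the mask; the helper name build_sdpa_mask, the (batch, seq_len) padding convention with 1 = keep and 0 = pad, and the broadcast shape are assumptions, not code from this repository.

# Illustrative sketch (not repository code): how a padding mask could be shaped
# for F.scaled_dot_product_attention once it is forwarded instead of being
# hard-coded to attn_mask=None.
import torch
import torch.nn.functional as F


def build_sdpa_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: (B, S) with 1 = keep, 0 = pad  ->  (B, 1, 1, S) bool,
    # where True means "may be attended to"; broadcasts over heads and queries.
    return attention_mask.bool()[:, None, None, :]


def masked_attention(q, k, v, attention_mask=None, dropout_p=0.0):
    # Mirrors the patched call site: the mask is passed through rather than dropped.
    attn_mask = build_sdpa_mask(attention_mask) if attention_mask is not None else None
    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=False,  # LLaDA blocks attend bidirectionally
    )


if __name__ == "__main__":
    B, H, S, D = 2, 4, 8, 16
    q = k = v = torch.randn(B, H, S, D)
    pad_mask = torch.ones(B, S, dtype=torch.long)
    pad_mask[1, 6:] = 0  # last two positions of the second sequence are padding
    out = masked_attention(q, k, v, attention_mask=pad_mask)
    print(out.shape)  # torch.Size([2, 4, 8, 16])
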
tokenizer_config.json CHANGED
@@ -2164,7 +2164,7 @@
     "<|number_end|>"
   ],
   "bos_token": "<|startoftext|>",
-  "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
+  "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{%- if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{%- endif %}",
   "clean_up_tokenization_spaces": false,
   "cls_token": "[CLS]",
   "eos_token": "<|endoftext|>",