fix attn mask & chat template

- modeling_llada.py  +15 -11
- tokenizer_config.json  +1 -1
modeling_llada.py

@@ -654,7 +654,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=False,
         )
@@ -665,6 +665,7 @@ class LLaDABlock(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -712,7 +713,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attention_mask,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             is_causal=False,
         )
@@ -785,6 +786,7 @@ class LLaDASequentialBlock(LLaDABlock):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -805,10 +807,10 @@ class LLaDASequentialBlock(LLaDABlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+                self.attention, q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+            att, cache = self.attention(q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache)

         # Add attention scores.
         # shape: (B, T, C)
@@ -887,6 +889,7 @@ class LLaDALlamaBlock(LLaDABlock):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -905,10 +908,10 @@ class LLaDALlamaBlock(LLaDABlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
+                self.attention, q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+            att, cache = self.attention(q, k, v, attention_bias, attention_mask, layer_past=layer_past, use_cache=use_cache)

         # Add attention scores.
         # shape: (B, T, C)
@@ -977,6 +980,7 @@ class LLaDABlockGroup(nn.ModuleList):
         self,
         x: torch.Tensor,
         attention_bias: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
@@ -1001,11 +1005,11 @@ class LLaDABlockGroup(nn.ModuleList):
             ):
                 # shape: (batch_size, seq_len, d_model)
                 x, cache = self._activation_checkpoint_fn(  # type: ignore
-                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                    block, x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache
                 )
             else:
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+                x, cache = block(x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache)
             if attn_key_values is not None:
                 assert cache is not None
                 attn_key_values.append(cache)
@@ -1308,11 +1312,11 @@ class LLaDAModel(nn.Module):
             ):
                 # shape: (batch_size, seq_len, d_model)
                 x, cache = self._activation_checkpoint_fn(
-                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                    block, x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache
                 )
             else:
                 # shape: (batch_size, seq_len, d_model)
-                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+                x, cache = block(x, attention_bias=attention_bias, attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache)
             if attn_key_values is not None:
                 assert cache is not None
                 attn_key_values.append(cache)
@@ -1330,7 +1334,7 @@ class LLaDAModel(nn.Module):
                 ]
             )
             x, cache = block_group(
-                x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
+                x, attention_bias=attention_bias, attention_mask=attention_mask, layers_past=layers_past, use_cache=use_cache
            )
            if attn_key_values is not None:
                assert cache is not None
tokenizer_config.json

@@ -2164,7 +2164,7 @@
     "<|number_end|>"
   ],
   "bos_token": "<|startoftext|>",
-  "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
+  "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{%- if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{%- endif %}",
   "clean_up_tokenization_spaces": false,
   "cls_token": "[CLS]",
   "eos_token": "<|endoftext|>",
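The template now emits the trailing assistant header only when `add_generation_prompt` is requested, matching the standard `apply_chat_template` contract. A short usage sketch; the repository id below is an assumption for illustration, not taken from this commit:

    from transformers import AutoTokenizer

    # Repository id assumed for illustration; substitute the repo this commit belongs to.
    tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True)

    messages = [{"role": "user", "content": "What is the capital of France?"}]

    with_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    without_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    # Only the first string should end with the assistant header; the second now stops at the last turn.
    print(with_prompt.endswith("<|start_header_id|>assistant<|end_header_id|>\n\n"))  # True
    print(without_prompt.endswith("<|eot_id|>"))                                      # True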