Update modeling_hyperclovax.py
Support Hugging Face transformers versions later than 4.54.0.
modeling_hyperclovax.py  CHANGED  (+11, -3)
@@ -1,6 +1,7 @@
 # coding=utf-8
 # This file was created for the HyperCLOVA X SEED 14B Think architecture.
-# partially copied and modified from
+# partially copied and modified from
+# https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/modeling_llama.py
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -43,7 +44,14 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
-from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+try:
+    from transformers.utils import LossKwargs
+    loss_kwargs_class = LossKwargs
+except ImportError:
+    from transformers.utils import TransformersKwargs
+    loss_kwargs_class = TransformersKwargs
+
 from .configuration_hyperclovax import HyperCLOVAXConfig
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask
@@ -620,7 +628,7 @@ class HyperCLOVAXModel(HyperCLOVAXPreTrainedModel):
         return causal_mask
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+class KwargsForCausalLM(FlashAttentionKwargs, loss_kwargs_class): ...
 
 
 @auto_docstring
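
Note on the shim (not part of the diff): per the commit message, LossKwargs is no longer importable from transformers.utils in releases after v4.54.0, where TransformersKwargs takes its place, so the try/except keeps a single alias (loss_kwargs_class) working on both sides of the break. Below is a minimal standalone sketch of the same pattern; the transformers.modeling_flash_attention_utils import path for FlashAttentionKwargs is an assumption here, since the diff's context lines don't show where the file imports it.

# Minimal sketch of the version-compatibility shim, assuming transformers is
# installed. The FlashAttentionKwargs import path is an assumption (see note).
try:
    # Present in older transformers releases (up to around v4.54.0).
    from transformers.utils import LossKwargs
    loss_kwargs_class = LossKwargs
except ImportError:
    # Replacement name in later releases.
    from transformers.utils import TransformersKwargs
    loss_kwargs_class = TransformersKwargs

from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

# Both names are TypedDict subclasses, so either works as a mixin base for the
# class that types the extra keyword arguments of the causal-LM forward pass.
class KwargsForCausalLM(FlashAttentionKwargs, loss_kwargs_class): ...

In the model file this class typically annotates the forward signature via Unpack (which the file already imports from transformers.processing_utils), e.g. **kwargs: Unpack[KwargsForCausalLM], so downstream call sites type-check against whichever kwargs class the installed transformers version provides.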