Update modeling_hyperclovax.py (#6)
Browse files- Update modeling_hyperclovax.py (1f212e3e2f00d627de3161750980468df377dc1d)
- modeling_hyperclovax.py +11 -3
modeling_hyperclovax.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
# coding=utf-8
|
| 2 |
# This file was created for the HyperCLOVA X SEED 14B Think architecture.
|
| 3 |
-
# partially copied and modified from
|
|
|
|
| 4 |
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 5 |
#
|
| 6 |
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
@@ -43,7 +44,14 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
|
|
| 43 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 44 |
from transformers.processing_utils import Unpack
|
| 45 |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 46 |
-
from transformers.utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
from .configuration_hyperclovax import HyperCLOVAXConfig
|
| 48 |
if is_torch_flex_attn_available():
|
| 49 |
from torch.nn.attention.flex_attention import BlockMask
|
|
@@ -620,7 +628,7 @@ class HyperCLOVAXModel(HyperCLOVAXPreTrainedModel):
|
|
| 620 |
return causal_mask
|
| 621 |
|
| 622 |
|
| 623 |
-
class KwargsForCausalLM(FlashAttentionKwargs,
|
| 624 |
|
| 625 |
|
| 626 |
@auto_docstring
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
# This file was created for the HyperCLOVA X SEED 14B Think architecture.
|
| 3 |
+
# partially copied and modified from
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/modeling_llama.py
|
| 5 |
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 6 |
#
|
| 7 |
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
|
|
| 44 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 45 |
from transformers.processing_utils import Unpack
|
| 46 |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 47 |
+
from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
| 48 |
+
try:
|
| 49 |
+
from transformers.utils import LossKwargs
|
| 50 |
+
loss_kwargs_class = LossKwargs
|
| 51 |
+
except ImportError:
|
| 52 |
+
from transformers.utils import TransformersKwargs
|
| 53 |
+
loss_kwargs_class = TransformersKwargs
|
| 54 |
+
|
| 55 |
from .configuration_hyperclovax import HyperCLOVAXConfig
|
| 56 |
if is_torch_flex_attn_available():
|
| 57 |
from torch.nn.attention.flex_attention import BlockMask
|
|
|
|
| 628 |
return causal_mask
|
| 629 |
|
| 630 |
|
| 631 |
+
class KwargsForCausalLM(FlashAttentionKwargs, loss_kwargs_class): ...
|
| 632 |
|
| 633 |
|
| 634 |
@auto_docstring
|