Update new model
modeling_phi.py  CHANGED  (+54 -38)
@@ -24,11 +24,14 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
     from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
     from flash_attn.ops.fused_dense import FusedDense
-except:
+    print("Using Flash Attention!")
+except Exception as exc:
+    print(exc)
     pad_input, unpad_input = None, None
     FlashRotaryEmbedding = None
     FlashSelfAttention, FlashCrossAttention = None, None
     FusedDense = None
+    print("Not using Flash Attention!")
 
 
 @dataclass
@@ -525,7 +528,7 @@ class MHA(nn.Module):
         softmax_scale: Optional[float] = None,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
-        checkpointing: bool = False,
+        checkpointing: bool = True,
     ) -> None:
         super().__init__()
 
@@ -607,7 +610,7 @@ class MHA(nn.Module):
 
             if self.checkpointing:
                 attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+                    self.inner_attn, qkv, None, cu_seqlens, max_seqlen, use_reentrant=False
                 )
             else:
                 attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
@@ -616,7 +619,7 @@ class MHA(nn.Module):
             return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
 
         if self.checkpointing:
-            return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
+            return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, None, key_padding_mask, use_reentrant=False)
 
         return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
 
@@ -669,11 +672,12 @@ class MHA(nn.Module):
                     self.inner_cross_attn,
                     q,
                     kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
+                    causal,
+                    cu_seqlens_q,
+                    max_seqlen_q,
+                    cu_seqlens_k,
+                    max_seqlen_k,
+                    use_reentrant=False,
                 )
             else:
                 attn_output = self.inner_cross_attn(
@@ -697,8 +701,9 @@ class MHA(nn.Module):
                 self.inner_cross_attn,
                 q,
                 kv,
-                key_padding_mask=key_padding_mask,
-                causal=causal,
+                causal,
+                key_padding_mask,
+                use_reentrant=False,
             )
 
         return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
@@ -835,7 +840,7 @@ class PhiPreTrainedModel(PreTrainedModel):
 
     config_class = PhiConfig
     base_model_prefix = "transformer"
-    supports_gradient_checkpointing = False
+    supports_gradient_checkpointing = True
     _no_split_modules = ["ParallelBlock"]
 
     def __init__(self, *inputs, **kwargs) -> None:
@@ -862,20 +867,20 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
-            past_key_values = InferenceParams(
-                max_seqlen=self.config.n_positions,
-                max_batch_size=input_ids.shape[0],
-                seqlen_offset=0,
-                batch_size_offset=0,
-                key_value_memory_dict={},
-                lengths_per_sample=None,
-            )
-        else:
-            # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
-            past_key_values.seqlen_offset = input_ids.shape[1] - 1
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-            attention_mask = attention_mask[:, -1].unsqueeze(-1)
+        # if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+        #     past_key_values = InferenceParams(
+        #         max_seqlen=self.config.n_positions,
+        #         max_batch_size=input_ids.shape[0],
+        #         seqlen_offset=0,
+        #         batch_size_offset=0,
+        #         key_value_memory_dict={},
+        #         lengths_per_sample=None,
+        #     )
+        # else:
+        #     # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
+        #     past_key_values.seqlen_offset = input_ids.shape[1] - 1
+        #     input_ids = input_ids[:, -1].unsqueeze(-1)
+        #     attention_mask = attention_mask[:, -1].unsqueeze(-1)
 
         return {
             "input_ids": input_ids,
@@ -891,17 +896,19 @@ class PhiModel(PhiPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
 
     def __init__(self, config: PhiConfig) -> None:
+        config.flash_attn = True
+        config.flash_rotary = True
         super().__init__(config)
 
         self.embd = Embedding(config)
        self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
-        self.gradient_checkpointing = False
+        self.gradient_checkpointing = True
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.embd
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embd.wte
 
-    def set_input_embeddings(self, new_embeddings) -> None:
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
         self.embd.wte = new_embeddings
 
     def forward(
@@ -919,11 +926,20 @@ class PhiModel(PhiPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         for layer in self.h:
-            hidden_states = layer(
-                hidden_states,
-                past_key_values=past_key_values,
-                attention_mask=attention_mask,
-            )
+            if self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    layer.__call__,
+                    hidden_states,
+                    past_key_values,
+                    attention_mask,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = layer(
+                    hidden_states,
+                    past_key_values=past_key_values,
+                    attention_mask=attention_mask,
+                )
 
         return hidden_states
 
@@ -947,10 +963,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
 
         self.post_init()
 
-    def get_output_embeddings(self):
-        return self.lm_head
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.lm_head.linear
 
-    def set_output_embeddings(self, new_embeddings) -> None:
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.lm_head.linear = new_embeddings
 
     def forward(
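For reference, the recurring change in this diff is that every torch.utils.checkpoint.checkpoint call now passes its inputs positionally and sets use_reentrant=False. Below is a minimal, self-contained sketch of that non-reentrant activation-checkpointing pattern; the toy module and tensor shapes are illustrative assumptions, not code from this repository.

import torch
import torch.utils.checkpoint

block = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GELU(), torch.nn.Linear(32, 16))
x = torch.randn(4, 16, requires_grad=True)

# checkpoint() skips storing the block's intermediate activations and
# recomputes them during the backward pass; use_reentrant=False selects the
# non-reentrant implementation that recent PyTorch releases recommend.
y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # gradients still reach the checkpointed inputs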