Commit
·
0f3418e
1
Parent(s):
10aca20
Got model running, but results are incorrect
Browse files- attention.py +3 -6
- config.json +2 -2
- phi2_configuration.py +18 -18
- phi2_model.py +1 -1
attention.py
CHANGED
|
@@ -28,7 +28,7 @@ class RotaryEmbedding(nn.Module):
|
|
| 28 |
d_rotary: int,
|
| 29 |
rotary_base: float = 10000.0,
|
| 30 |
initial_cos_sin_cache_len: int = 2048,
|
| 31 |
-
device: torch.device
|
| 32 |
) -> None:
|
| 33 |
super().__init__()
|
| 34 |
self.d_rotary = d_rotary
|
|
@@ -52,7 +52,6 @@ class RotaryEmbedding(nn.Module):
|
|
| 52 |
torch.arange(
|
| 53 |
start=0,
|
| 54 |
end=self.d_rotary,
|
| 55 |
-
step=2,
|
| 56 |
device=self.device,
|
| 57 |
dtype=self.dtype,
|
| 58 |
) / self.d_rotary
|
|
@@ -61,8 +60,8 @@ class RotaryEmbedding(nn.Module):
|
|
| 61 |
# torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
|
| 62 |
# TODO: does this matter if I'm disabling torch.autocast?
|
| 63 |
m_theta_i = torch.outer(m, theta_i)
|
| 64 |
-
self._cos_cached = torch.cos(m_theta_i).to(self.dtype)
|
| 65 |
-
self._sin_cached = torch.sin(m_theta_i).to(self.dtype)
|
| 66 |
|
| 67 |
# TODO: scale_base caching is labelled as not yet done in Phi2
|
| 68 |
"""
|
|
@@ -108,8 +107,6 @@ class RotaryEmbedding(nn.Module):
|
|
| 108 |
if (
|
| 109 |
not self._max_seqlen
|
| 110 |
or self._max_seqlen < x.shape[1] + seqlen_offset
|
| 111 |
-
or self._cos_cached.device != x.device
|
| 112 |
-
or self._cos_cached.dtype != x.dtype
|
| 113 |
or (self.training and self._cos_cached.is_inference())
|
| 114 |
):
|
| 115 |
self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
|
|
|
|
| 28 |
d_rotary: int,
|
| 29 |
rotary_base: float = 10000.0,
|
| 30 |
initial_cos_sin_cache_len: int = 2048,
|
| 31 |
+
device: torch.device = "cuda",
|
| 32 |
) -> None:
|
| 33 |
super().__init__()
|
| 34 |
self.d_rotary = d_rotary
|
|
|
|
| 52 |
torch.arange(
|
| 53 |
start=0,
|
| 54 |
end=self.d_rotary,
|
|
|
|
| 55 |
device=self.device,
|
| 56 |
dtype=self.dtype,
|
| 57 |
) / self.d_rotary
|
|
|
|
| 60 |
# torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
|
| 61 |
# TODO: does this matter if I'm disabling torch.autocast?
|
| 62 |
m_theta_i = torch.outer(m, theta_i)
|
| 63 |
+
self._cos_cached = torch.cos(m_theta_i).to(self.dtype).to(self.device)
|
| 64 |
+
self._sin_cached = torch.sin(m_theta_i).to(self.dtype).to(self.device)
|
| 65 |
|
| 66 |
# TODO: scale_base caching is labelled as not yet done in Phi2
|
| 67 |
"""
|
|
|
|
| 107 |
if (
|
| 108 |
not self._max_seqlen
|
| 109 |
or self._max_seqlen < x.shape[1] + seqlen_offset
|
|
|
|
|
|
|
| 110 |
or (self.training and self._cos_cached.is_inference())
|
| 111 |
):
|
| 112 |
self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
|
config.json
CHANGED
|
@@ -17,8 +17,8 @@
|
|
| 17 |
"vocab_chunk_for_gpu_efficiency": 64,
|
| 18 |
"initial_cos_sin_cache_len": 2048,
|
| 19 |
"d_embedding": 2560,
|
| 20 |
-
"
|
| 21 |
-
"
|
| 22 |
"use_flash_attn": false,
|
| 23 |
"use_flash_rotary": false,
|
| 24 |
"use_fused_dense": false,
|
|
|
|
| 17 |
"vocab_chunk_for_gpu_efficiency": 64,
|
| 18 |
"initial_cos_sin_cache_len": 2048,
|
| 19 |
"d_embedding": 2560,
|
| 20 |
+
"n_attn_blocks": 32,
|
| 21 |
+
"n_attn_heads": 32,
|
| 22 |
"use_flash_attn": false,
|
| 23 |
"use_flash_rotary": false,
|
| 24 |
"use_fused_dense": false,
|
phi2_configuration.py
CHANGED
|
@@ -8,27 +8,27 @@ class Phi2Config(PretrainedConfig):
|
|
| 8 |
"max_position_embeddings": "initial_cos_sin_cache_len",
|
| 9 |
"hidden_size": "d_embedding",
|
| 10 |
"num_attention_heads": "n_attn_heads",
|
| 11 |
-
"num_hidden_layers": "
|
| 12 |
}
|
| 13 |
|
| 14 |
def __init__(
|
| 15 |
self,
|
| 16 |
-
vocab_size: int
|
| 17 |
-
vocab_chunk_for_gpu_efficiency: int
|
| 18 |
-
initial_cos_sin_cache_len: int
|
| 19 |
-
d_embedding: int
|
| 20 |
-
|
| 21 |
-
n_attn_heads: int
|
| 22 |
-
use_flash_attn: bool
|
| 23 |
-
use_flash_rotary: bool
|
| 24 |
-
use_fused_dense: bool
|
| 25 |
-
attn_pdrop: float
|
| 26 |
-
embd_pdrop: float
|
| 27 |
-
resid_pdrop: float
|
| 28 |
-
layer_norm_epsilon: float
|
| 29 |
-
weight_initialization_range: float
|
| 30 |
-
tie_word_embeddings: bool
|
| 31 |
-
checkpointing: bool
|
| 32 |
**kwargs
|
| 33 |
) -> None:
|
| 34 |
self.vocab_size = (
|
|
@@ -38,7 +38,7 @@ class Phi2Config(PretrainedConfig):
|
|
| 38 |
)
|
| 39 |
self.initial_cos_sin_cache_len = initial_cos_sin_cache_len
|
| 40 |
self.d_embedding = d_embedding
|
| 41 |
-
self.
|
| 42 |
self.n_attn_heads = n_attn_heads
|
| 43 |
self.use_flash_attn = use_flash_attn
|
| 44 |
self.use_flash_rotary = use_flash_rotary
|
|
|
|
| 8 |
"max_position_embeddings": "initial_cos_sin_cache_len",
|
| 9 |
"hidden_size": "d_embedding",
|
| 10 |
"num_attention_heads": "n_attn_heads",
|
| 11 |
+
"num_hidden_layers": "n_attn_blocks",
|
| 12 |
}
|
| 13 |
|
| 14 |
def __init__(
|
| 15 |
self,
|
| 16 |
+
vocab_size: int, # this includes the extra tokens included by Phi2 in tokenizer_config.json
|
| 17 |
+
vocab_chunk_for_gpu_efficiency: int,
|
| 18 |
+
initial_cos_sin_cache_len: int,
|
| 19 |
+
d_embedding: int,
|
| 20 |
+
n_attn_blocks: int,
|
| 21 |
+
n_attn_heads: int,
|
| 22 |
+
use_flash_attn: bool,
|
| 23 |
+
use_flash_rotary: bool,
|
| 24 |
+
use_fused_dense: bool,
|
| 25 |
+
attn_pdrop: float,
|
| 26 |
+
embd_pdrop: float,
|
| 27 |
+
resid_pdrop: float,
|
| 28 |
+
layer_norm_epsilon: float,
|
| 29 |
+
weight_initialization_range: float,
|
| 30 |
+
tie_word_embeddings: bool, # whether embedding weights are shared between the encoder and decoder
|
| 31 |
+
checkpointing: bool, # whether to use gradient checkpointing to reduce memory usage (I think)
|
| 32 |
**kwargs
|
| 33 |
) -> None:
|
| 34 |
self.vocab_size = (
|
|
|
|
| 38 |
)
|
| 39 |
self.initial_cos_sin_cache_len = initial_cos_sin_cache_len
|
| 40 |
self.d_embedding = d_embedding
|
| 41 |
+
self.n_attn_blocks = n_attn_blocks
|
| 42 |
self.n_attn_heads = n_attn_heads
|
| 43 |
self.use_flash_attn = use_flash_attn
|
| 44 |
self.use_flash_rotary = use_flash_rotary
|
phi2_model.py
CHANGED
|
@@ -106,7 +106,7 @@ class Phi2Model(Phi2PreTrainedModel):
|
|
| 106 |
use_fused_dense=config.use_fused_dense,
|
| 107 |
checkpointing=config.checkpointing,
|
| 108 |
)
|
| 109 |
-
for i in range(config.
|
| 110 |
])
|
| 111 |
self.gradient_checkpointing_disable() # https://github.com/cybertronai/gradient-checkpointing - I think this is turned off due to flash attention?
|
| 112 |
self.post_init() # calls self._init_weights() for all modules
|
|
|
|
| 106 |
use_fused_dense=config.use_fused_dense,
|
| 107 |
checkpointing=config.checkpointing,
|
| 108 |
)
|
| 109 |
+
for i in range(config.n_attn_blocks)
|
| 110 |
])
|
| 111 |
self.gradient_checkpointing_disable() # https://github.com/cybertronai/gradient-checkpointing - I think this is turned off due to flash attention?
|
| 112 |
self.post_init() # calls self._init_weights() for all modules
|