Update attention.py
attention.py +1 -1
attention.py CHANGED
@@ -87,7 +87,7 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
 
 def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
-        from
+        from flash_attn import flash_attn_triton
     except:
         raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
     check_valid_inputs(query, key, value)
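
For context, the repaired line completes a deferred-import guard: the optional Triton flash-attention kernel is imported inside triton_flash_attn_fn, so the module still loads when flash-attn is not installed and the pinned-version error only fires when this code path actually runs. Below is a minimal sketch of that pattern; the trimmed signature, the sketch's name, and the flash_attn_func call shape (tensors laid out as batch, seqlen, n_heads, head_dim, per the flash-attn 1.0.3 Triton API) are assumptions for illustration, not part of this diff.

# Minimal sketch of the deferred-import guard fixed above; simplified
# signature and call are illustrative, not the repo's exact code.
def triton_flash_attn_sketch(query, key, value, softmax_scale=None, is_causal=False):
    try:
        # Importing inside the function keeps the surrounding module
        # importable without the optional dependency; the error surfaces
        # only when the Triton attention path is actually exercised.
        from flash_attn import flash_attn_triton
    except ImportError:
        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
    # Assumed flash-attn 1.0.3 API: flash_attn_func(q, k, v, bias, causal,
    # softmax_scale) over (batch, seqlen, n_heads, head_dim) tensors.
    return flash_attn_triton.flash_attn_func(query, key, value, None, is_causal, softmax_scale)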