Spaces:
Running
on
Zero
Running
on
Zero
Update fish_speech/models/text2semantic/llama.py
Browse files
fish_speech/models/text2semantic/llama.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import dataclasses
|
| 2 |
import json
|
| 3 |
import math
|
|
@@ -862,7 +863,8 @@ class Attention(nn.Module):
|
|
| 862 |
|
| 863 |
L, S = query.size(-2), key.size(-2)
|
| 864 |
scale_factor = 1 / math.sqrt(query.size(-1))
|
| 865 |
-
|
|
|
|
| 866 |
|
| 867 |
if attn_mask is not None:
|
| 868 |
if attn_mask.dtype == torch.bool:
|
|
@@ -938,3 +940,4 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
|
| 938 |
|
| 939 |
x_out2 = x_out2.flatten(3)
|
| 940 |
return x_out2.type_as(x)
|
|
|
|
|
|
| 1 |
+
|
| 2 |
import dataclasses
|
| 3 |
import json
|
| 4 |
import math
|
|
|
|
| 863 |
|
| 864 |
L, S = query.size(-2), key.size(-2)
|
| 865 |
scale_factor = 1 / math.sqrt(query.size(-1))
|
| 866 |
+
# FIX: Use query.new_zeros so that no device object is passed to torch.zeros, which causes a torch.compile error.
|
| 867 |
+
attn_bias = query.new_zeros(1, 1, L, S)
|
| 868 |
|
| 869 |
if attn_mask is not None:
|
| 870 |
if attn_mask.dtype == torch.bool:
|
|
|
|
| 940 |
|
| 941 |
x_out2 = x_out2.flatten(3)
|
| 942 |
return x_out2.type_as(x)
|
| 943 |
+
|