ford442 committed on
Commit
da66ad7
·
verified ·
1 Parent(s): 7db287a

Update fish_speech/models/text2semantic/llama.py

Browse files
fish_speech/models/text2semantic/llama.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import dataclasses
2
  import json
3
  import math
@@ -862,7 +863,8 @@ class Attention(nn.Module):
862
 
863
  L, S = query.size(-2), key.size(-2)
864
  scale_factor = 1 / math.sqrt(query.size(-1))
865
- attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
 
866
 
867
  if attn_mask is not None:
868
  if attn_mask.dtype == torch.bool:
@@ -938,3 +940,4 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
938
 
939
  x_out2 = x_out2.flatten(3)
940
  return x_out2.type_as(x)
 
 
1
+
2
  import dataclasses
3
  import json
4
  import math
 
863
 
864
  L, S = query.size(-2), key.size(-2)
865
  scale_factor = 1 / math.sqrt(query.size(-1))
866
+ # FIX: Use new_zeros to avoid passing device object to torch.zeros which causes torch.compile error
867
+ attn_bias = query.new_zeros(1, 1, L, S)
868
 
869
  if attn_mask is not None:
870
  if attn_mask.dtype == torch.bool:
 
940
 
941
  x_out2 = x_out2.flatten(3)
942
  return x_out2.type_as(x)
943
+