ford442 committed
Commit 7db287a · verified · 1 Parent(s): 1988507

Update fish_speech/models/text2semantic/inference.py

fish_speech/models/text2semantic/inference.py CHANGED
@@ -1,3 +1,4 @@
+
 import os
 import queue
 import threading
@@ -135,7 +136,8 @@ def decode_one_token_ar(
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)
 
-    input_pos = torch.tensor([0], dtype=torch.long).to(hidden_states.device)
+    # FIX: Use new_zeros to avoid torch.compile issues with device argument
+    input_pos = hidden_states.new_zeros((1,), dtype=torch.long)
     model.forward_generate_fast(hidden_states, input_pos)
     a = codebooks[0] - model.tokenizer.semantic_begin_id
     a[a < 0] = 0
@@ -143,9 +145,8 @@ def decode_one_token_ar(
     codebooks.append(a)
 
     for codebook_idx in range(1, model.config.num_codebooks):
-        input_pos = torch.tensor(
-            [codebook_idx], dtype=torch.long
-        ).to(hidden_states.device)
+        # FIX: Use new_full to avoid torch.compile issues with device argument
+        input_pos = hidden_states.new_full((1,), codebook_idx, dtype=torch.long)
         logits = model.forward_generate_fast(hidden_states, input_pos)
 
         short_logits = logits[:, :, :1024]
@@ -704,3 +705,4 @@ def main(
 
 if __name__ == "__main__":
     main()
+
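
Note (not part of the commit): a minimal, self-contained sketch of the pattern the two FIX comments describe. Instead of building the input_pos index on the CPU and then moving it with .to(hidden_states.device), Tensor.new_zeros / Tensor.new_full allocate it directly with the device of hidden_states, which is the form the commit switches to; per the commit's own comments, the motivation is avoiding torch.compile issues with the explicit device transfer. The dummy hidden_states tensor below is illustrative only, not the model's real activations.

import torch

# Stand-in for the model's activations; shape and device are illustrative.
hidden_states = torch.randn(1, 1, 1024)

# Old pattern (removed by this commit): build on CPU, then move with .to(...).
input_pos_old = torch.tensor([0], dtype=torch.long).to(hidden_states.device)

# New pattern: new_zeros / new_full allocate the index tensor directly on
# hidden_states' device, with no separate .to(...) step.
input_pos_first = hidden_states.new_zeros((1,), dtype=torch.long)             # tensor([0])
codebook_idx = 3
input_pos_nth = hidden_states.new_full((1,), codebook_idx, dtype=torch.long)  # tensor([3])

assert torch.equal(input_pos_old, input_pos_first)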