Spaces:
Paused
Paused
Update fish_speech/models/text2semantic/inference.py
Browse files
fish_speech/models/text2semantic/inference.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import queue
|
| 3 |
import threading
|
|
@@ -135,7 +136,8 @@ def decode_one_token_ar(
|
|
| 135 |
layer.attention.kv_cache.k_cache.fill_(0)
|
| 136 |
layer.attention.kv_cache.v_cache.fill_(0)
|
| 137 |
|
| 138 |
-
input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
|
|
|
|
| 139 |
model.forward_generate_fast(hidden_states, input_pos)
|
| 140 |
a = codebooks[0] - model.tokenizer.semantic_begin_id
|
| 141 |
a[a < 0] = 0
|
|
@@ -143,9 +145,8 @@ def decode_one_token_ar(
|
|
| 143 |
codebooks.append(a)
|
| 144 |
|
| 145 |
for codebook_idx in range(1, model.config.num_codebooks):
|
| 146 |
-
input_pos = torch.tensor(
|
| 147 |
-
    [codebook_idx], dtype=torch.long
|
| 148 |
-
).to(hidden_states.device)
|
| 149 |
logits = model.forward_generate_fast(hidden_states, input_pos)
|
| 150 |
|
| 151 |
short_logits = logits[:, :, :1024]
|
|
@@ -704,3 +705,4 @@ def main(
|
|
| 704 |
|
| 705 |
if __name__ == "__main__":
|
| 706 |
main()
|
|
|
|
|
|
| 1 |
+
|
| 2 |
import os
|
| 3 |
import queue
|
| 4 |
import threading
|
|
|
|
| 136 |
layer.attention.kv_cache.k_cache.fill_(0)
|
| 137 |
layer.attention.kv_cache.v_cache.fill_(0)
|
| 138 |
|
| 139 |
+
# FIX: Use new_zeros to avoid torch.compile issues with device argument
|
| 140 |
+
input_pos = hidden_states.new_zeros((1,), dtype=torch.long)
|
| 141 |
model.forward_generate_fast(hidden_states, input_pos)
|
| 142 |
a = codebooks[0] - model.tokenizer.semantic_begin_id
|
| 143 |
a[a < 0] = 0
|
|
|
|
| 145 |
codebooks.append(a)
|
| 146 |
|
| 147 |
for codebook_idx in range(1, model.config.num_codebooks):
|
| 148 |
+
# FIX: Use new_full to avoid torch.compile issues with device argument
|
| 149 |
+
input_pos = hidden_states.new_full((1,), codebook_idx, dtype=torch.long)
|
|
|
|
| 150 |
logits = model.forward_generate_fast(hidden_states, input_pos)
|
| 151 |
|
| 152 |
short_logits = logits[:, :, :1024]
|
|
|
|
| 705 |
|
| 706 |
if __name__ == "__main__":
|
| 707 |
main()
|
| 708 |
+
|