[Fix] Fix code bugs in modeling for multiturn inference (#17)
Browse files- [Fix] Fix code bugs in modeling for multiturn inference (c60abbbdb143ceda192487c06d03b49ba231953f)
Co-authored-by: Mashiro <[email protected]>
modeling_moonshot_kimia.py
CHANGED
|
@@ -685,14 +685,13 @@ class MoonshotKimiaModel(Qwen2PreTrainedModel):
|
|
| 685 |
.to(torch.cuda.current_device())
|
| 686 |
.to(whisper_dtype)
|
| 687 |
)
|
| 688 |
-
|
|
|
|
| 689 |
media_start_idx, media_end_idx
|
| 690 |
-
):
|
| 691 |
-
# assert whisper_emb.shape[1] == end_idx - (start_idx + 1)
|
| 692 |
|
| 693 |
feat_len = end_idx - (start_idx + 1)
|
| 694 |
whisper_input_feature_i = whisper_input_feature[seg_idx].squeeze(0)
|
| 695 |
-
assert feat_len == is_continuous_mask[seg_idx].sum()
|
| 696 |
expanded_whisper[start_idx + 1 : end_idx, :] = (
|
| 697 |
whisper_input_feature_i[:feat_len, :]
|
| 698 |
)
|
|
|
|
| 685 |
.to(torch.cuda.current_device())
|
| 686 |
.to(whisper_dtype)
|
| 687 |
)
|
| 688 |
+
assert (media_end_idx - media_start_idx).sum() - media_start_idx.shape[0] == is_continuous_mask.sum()
|
| 689 |
+
for seg_idx, ((batch_idx, start_idx), (_, end_idx)) in enumerate(zip(
|
| 690 |
media_start_idx, media_end_idx
|
| 691 |
+
)):
|
|
|
|
| 692 |
|
| 693 |
feat_len = end_idx - (start_idx + 1)
|
| 694 |
whisper_input_feature_i = whisper_input_feature[seg_idx].squeeze(0)
|
|
|
|
| 695 |
expanded_whisper[start_idx + 1 : end_idx, :] = (
|
| 696 |
whisper_input_feature_i[:feat_len, :]
|
| 697 |
)
|