bigmoyan KamioMitsuzu commited on
Commit
4b4b7bf
·
verified ·
1 Parent(s): a574f67

[Fix] Fix code bugs in modeling for multiturn inference (#17)

Browse files

- [Fix] Fix code bugs in modeling for multiturn inference (c60abbbdb143ceda192487c06d03b49ba231953f)


Co-authored-by: Mashiro <[email protected]>

Files changed (1) hide show
  1. modeling_moonshot_kimia.py +3 -4
modeling_moonshot_kimia.py CHANGED
@@ -685,14 +685,13 @@ class MoonshotKimiaModel(Qwen2PreTrainedModel):
685
  .to(torch.cuda.current_device())
686
  .to(whisper_dtype)
687
  )
688
- for (seg_idx, start_idx), (_, end_idx) in zip(
 
689
  media_start_idx, media_end_idx
690
- ):
691
- # assert whisper_emb.shape[1] == end_idx - (start_idx + 1)
692
 
693
  feat_len = end_idx - (start_idx + 1)
694
  whisper_input_feature_i = whisper_input_feature[seg_idx].squeeze(0)
695
- assert feat_len == is_continuous_mask[seg_idx].sum()
696
  expanded_whisper[start_idx + 1 : end_idx, :] = (
697
  whisper_input_feature_i[:feat_len, :]
698
  )
 
685
  .to(torch.cuda.current_device())
686
  .to(whisper_dtype)
687
  )
688
+ assert (media_end_idx - media_start_idx).sum() - media_start_idx.shape[0] == is_continuous_mask.sum()
689
+ for seg_idx, ((batch_idx, start_idx), (_, end_idx)) in enumerate(zip(
690
  media_start_idx, media_end_idx
691
+ )):
 
692
 
693
  feat_len = end_idx - (start_idx + 1)
694
  whisper_input_feature_i = whisper_input_feature[seg_idx].squeeze(0)
 
695
  expanded_whisper[start_idx + 1 : end_idx, :] = (
696
  whisper_input_feature_i[:feat_len, :]
697
  )