Update modeling_minicpmo.py

modeling_minicpmo.py (+10 -6)
@@ -377,10 +377,12 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])

+        new_vllm_embedding = vllm_embedding.clone()
+
         vision_hidden_states = [
             i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
         ]
-
+
         bs = len(data["input_ids"])
         for i in range(bs):
             cur_vs_hs = vision_hidden_states[i]
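The added clone() presumably exists so that the per-sample edits further down write into a copy instead of into the tensor returned by embed_tokens; mutating an intermediate activation in place is a common way to trip autograd, especially under gradient checkpointing. A minimal, self-contained sketch of the clone-then-write pattern (embedding size, module, and indices below are made up for illustration, not the model's actual data):

import torch

embed = torch.nn.Embedding(100, 16)      # stand-in for self.llm.model.embed_tokens
token_ids = torch.randint(0, 100, (1, 8))

base = embed(token_ids)[0]               # (8, 16) activation still tracked by autograd
patched = base.clone()                   # copy that is safe to overwrite row by row
patched[2:5] = torch.randn(3, 16)        # e.g. splice in precomputed features

patched.sum().backward()                 # gradients still reach the embedding table
print(embed.weight.grad is not None)     # True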
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)

-                    cur_vllm_emb.scatter_(
+                    new_vllm_embedding[i] = cur_vllm_emb.scatter(
                         0,
                         image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                         cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                     )
+
                 elif self.training:
-                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+                    new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0

-        return vllm_embedding, vision_hidden_states
+        return new_vllm_embedding, vision_hidden_states

     def get_audio_embedding_streaming(self, data):
         r"""
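Tensor.scatter_ overwrites cur_vllm_emb in place, while Tensor.scatter returns a new tensor, which is then stored into the cloned buffer; the zero-valued mean() term in the elif branch looks like the usual way of keeping the vision features attached to the graph for samples that carry no image. A rough standalone sketch of the out-of-place scatter (shapes, indices, and variable roles are illustrative only):

import torch

emb_dim = 4
cur_vllm_emb = torch.randn(10, emb_dim)     # one sample's token embeddings
cur_vs_hs = torch.randn(3, emb_dim)         # features for 3 placeholder tokens
image_indices = torch.tensor([2, 3, 4])     # where those tokens sit in the sequence

# scatter() leaves cur_vllm_emb untouched and returns the patched copy;
# scatter_() would have overwritten rows 2..4 of cur_vllm_emb itself.
patched = cur_vllm_emb.scatter(
    0,
    image_indices.view(-1, 1).repeat(1, emb_dim),
    cur_vs_hs,
)

print(torch.equal(patched[2:5], cur_vs_hs))        # True: rows replaced in the copy
print(torch.equal(cur_vllm_emb[2:5], cur_vs_hs))   # False: original left alone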
@@ -595,7 +598,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         elif self.training:
             for i in range(bs):
                 # dummy audio_embeddings
-                input_embeddings += audio_embeddings[0].mean() * 0
+                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0

         return input_embeddings

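`input_embeddings += ...` mutates the tensor in place, whereas `input_embeddings = input_embeddings + ...` builds a new one, which autograd handles without version-counter complaints. The `* 0` term itself is the common trick for giving an otherwise-unused branch (here the audio path, when a training batch has no real audio) a zero-valued contribution so its parameters still receive gradients. A small illustration under those assumptions, with a made-up stand-in module:

import torch

audio_branch = torch.nn.Linear(8, 8)        # stand-in for the audio encoder/projector
input_embeddings = torch.randn(2, 5, 8)     # stand-in text embeddings

dummy = audio_branch(torch.zeros(1, 8))     # run the branch on dummy input
input_embeddings = input_embeddings + dummy.mean() * 0   # adds exactly zero, out-of-place

input_embeddings.sum().backward()
print(audio_branch.weight.grad.abs().max()) # tensor(0.) -- the gradient exists, so
                                            # DDP/ZeRO-style engines see no unused parameters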
@@ -751,7 +754,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         input_ids=None,
         pixel_values=None,
         tgt_sizes=None,
-        audio_features=
+        audio_features=[],
         audio_feature_lens=None,
         image_bound=None,
         audio_bounds=None,
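With an empty list as the default for audio_features, downstream length checks and loops work without a separate None guard. A toy sketch of that convention (the function below is hypothetical, not the model's API); the usual caveat about mutable default arguments applies, but it is harmless as long as the default is never mutated:

def describe(audio_features=[]):            # hypothetical helper, not from the model
    # len()/iteration work directly; a None default would need an extra check first
    if len(audio_features) > 0:
        return f"{len(audio_features)} audio segment(s)"
    return "no audio"

print(describe())                # "no audio"
print(describe([b"segment-0"]))  # "1 audio segment(s)"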
@@ -2655,6 +2658,7 @@ class ConditionalChatTTS(PreTrainedModel):
     """

     config_class = ConditionalChatTTSConfig
+    _no_split_modules = []

     def __init__(self, config: ConditionalChatTTSConfig):
         super().__init__(config)
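In transformers, the class-level _no_split_modules attribute is what from_pretrained(..., device_map=...) consults to learn which submodules must stay on a single device; if the attribute is left undefined, dispatching with a device map is refused. Declaring it as an empty list signals that any submodule of ConditionalChatTTS may be placed independently. A minimal sketch of the convention with made-up classes (TinyConfig/TinyModel are not part of this repository):

import torch
from transformers import PretrainedConfig, PreTrainedModel

class TinyConfig(PretrainedConfig):
    model_type = "tiny-demo"

class TinyModel(PreTrainedModel):
    config_class = TinyConfig
    # Without this attribute, loading with device_map="auto" raises an error
    # saying the model does not support a device map; [] means "split anywhere".
    _no_split_modules = []

    def __init__(self, config):
        super().__init__(config)
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

model = TinyModel(TinyConfig())
print(model._no_split_modules)   # []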