# InteractiveOmni-8B / modeling_voicelm.py
# --------------------------------------------------------
# SenseTime
# Copyright (c) 2025 SenseTime
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import logging
import math
from typing import List

import torch
from torch import nn
from transformers import PreTrainedModel, Qwen2ForCausalLM

from .configuration_voicelm import VoiceLMConfig

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class Qwen2Encoder(torch.nn.Module):
    """Thin wrapper around Qwen2ForCausalLM exposing single-step, cache-aware decoding."""

    def __init__(self, config):
        super().__init__()
        self.model = Qwen2ForCausalLM(config)

    def forward_one_step(self, xs, masks, cache=None):
        # masks is a (1, seq_len, seq_len) causal mask; its last row covers every position
        # visible to the newest token, which is what the HF attention_mask expects here.
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache
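

# Illustrative sketch only (not part of the released model code): it shows how
# forward_one_step can be driven step by step with a growing causal mask and the returned
# KV cache, which is the pattern inference_bistream below relies on. The helper name and
# the `embed_step` callback are hypothetical.
def _sketch_stepwise_decode(encoder, first_chunk: torch.Tensor, steps: int, embed_step):
    cache = None
    lm_input = first_chunk  # (1, T, hidden) prompt embeddings
    last_hiddens = []
    for _ in range(steps):
        # causal mask over prompt + cached tokens + the new chunk
        seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
        masks = torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool)
        hidden, cache = encoder.forward_one_step(lm_input, masks, cache=cache)
        last_hiddens.append(hidden[:, -1])
        # embed_step maps the newest hidden state to the next (1, 1, hidden) input embedding
        lm_input = embed_step(hidden[:, -1])
    return last_hiddens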


class VoiceLM(PreTrainedModel):
    """Autoregressive speech-token language model built on a Qwen2 backbone."""

    def __init__(self, config: VoiceLMConfig):
        super().__init__(config)
        self.llm_input_size = config.llm_input_size
        self.llm_output_size = config.llm_output_size
        self.speech_token_size = config.speech_token_size  # 6561
        self.sampling_config = config.sampling_config
        # special ids: sos_eos and task_id index llm_embedding below;
        # the fill token maps to class speech_token_size + 2 in the decoder output space
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2
        self.llm_embedding = torch.nn.Embedding(2, config.llm_input_size)
        self.llm = Qwen2Encoder(config.llm_config)
        # project hidden states to the speech-token vocabulary plus 3 extra classes
        # (eos = speech_token_size, fill = speech_token_size + 2)
        self.llm_decoder = nn.Linear(config.llm_output_size, config.speech_token_size + 3)
        # speech token embedding (6564, 896)
        self.speech_embedding = torch.nn.Embedding(
            config.speech_token_size + 3,
            config.llm_input_size,
        )

    # Repetition Aware Sampling from VALL-E 2: fall back to plain random sampling when the
    # nucleus-sampled token already repeats too often inside the recent decoding window.
    def ras_sampling(self, weighted_scores: torch.Tensor, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
        top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
        # count how often the candidate token already appears in the last win_size tokens
        rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
        if rep_num >= win_size * tau_r:
            top_ids = self.random_sampling(weighted_scores, decoded_tokens, sampling)
        return top_ids

    def nucleus_sampling(self, weighted_scores: torch.Tensor, top_p=0.8, top_k=25):
        prob, indices = [], []
        cum_prob = 0.0
        sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
        for i in range(len(sorted_idx)):
            # keep tokens until either the top-p mass or the top-k count is exhausted
            if cum_prob < top_p and len(prob) < top_k:
                cum_prob += sorted_value[i]
                prob.append(sorted_value[i])
                indices.append(sorted_idx[i])
            else:
                break
        prob = torch.tensor(prob).to(weighted_scores)
        indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
        top_ids = indices[prob.multinomial(1, replacement=True)]
        return top_ids

    def random_sampling(self, weighted_scores: torch.Tensor, decoded_tokens, sampling):
        # sample from the full softmax distribution, ignoring the decoding history
        top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
        return top_ids

    def sampling_ids(
        self,
        weighted_scores: torch.Tensor,
        decoded_tokens: List,
        sampling: int,
        ignore_eos: bool = True,
    ):
        # resample until a non-eos token is drawn when ignore_eos is True
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.ras_sampling(weighted_scores, decoded_tokens, sampling, **self.sampling_config)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials ({}) and still got eos while ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference_bistream(
        self,
        input_feature: torch.Tensor,
        target_text_feature: torch.Tensor,
        sampling: int = 25,
        mix_ratio: List[int] = [5, 25],
    ):
        text_token_len = target_text_feature.size(1)
        # 1. prepare input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        lm_input = torch.concat([sos_eos_emb, input_feature], dim=1)
        # 2. iterate text: interleave chunks of mix_ratio[0] text tokens with speech-token decoding
        out_tokens = []
        return_out_tokens = []
        cache = None
        text_cache = target_text_feature
        next_fill_index = -1
        for j in range(int(math.floor(text_token_len / mix_ratio[0]))):
            # append a new text chunk either at the very start or right after a fill token
            if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == (1 + input_feature.size(1))):
                logger.info('got fill token, need to append more text tokens')
                if text_cache.size(1) >= mix_ratio[0]:
                    lm_input_text = text_cache[:, :mix_ratio[0]]
                    logger.info('appending {} text tokens'.format(lm_input_text.size(1)))
                    if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
                        lm_input = lm_input_text
                    else:
                        lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                    text_cache = text_cache[:, mix_ratio[0]:]
                else:
                    logger.info('not enough text tokens to decode, waiting for more')
                    continue
            # decode speech tokens until the model asks for more text via a fill token
            while True:
                seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                y_pred, cache = self.llm.forward_one_step(
                    lm_input,
                    masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                    cache=cache,
                )
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                    # force a fill token every mix_ratio[1] + 1 positions once the schedule starts
                    top_ids = self.speech_token_size + 2
                    next_fill_index += (mix_ratio[1] + 1)
                else:
                    top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                if top_ids == self.speech_token_size + 2:
                    next_fill_index = len(out_tokens) + mix_ratio[1] + 1
                    logger.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                out_tokens.append(top_ids)
                if top_ids >= self.speech_token_size:
                    if top_ids == self.speech_token_size + 2:
                        break
                    else:
                        raise ValueError('should not get token {}'.format(top_ids))
                # yield top_ids
                return_out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode: feed the remaining text plus the task-id embedding and decode until eos
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logger.info('no more text tokens, decoding until eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(
                lm_input,
                masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                cache=cache,
            )
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.speech_token_size:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            # yield top_ids
            return_out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        return return_out_tokens
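

# Hedged usage sketch (not shipped with the model): one way inference_bistream might be called
# once a checkpoint is available. The checkpoint path is a placeholder and the dummy feature
# lengths are arbitrary; in practice the features come from the InteractiveOmni frontend.
if __name__ == "__main__":
    model = VoiceLM.from_pretrained("path/to/voicelm-checkpoint")  # placeholder path
    model.eval()
    # batch size 1, feature width matching the LLM input size
    input_feature = torch.randn(1, 32, model.llm_input_size)
    target_text_feature = torch.randn(1, 20, model.llm_input_size)
    speech_tokens = model.inference_bistream(input_feature, target_text_feature)
    print('generated {} speech tokens'.format(len(speech_tokens)))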