|
|
|
|
|
|
|
|
import logging
|
|
|
from typing import Union, Optional, List
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from torch import nn, Tensor
|
|
|
from transformers import LlamaModel, LlamaConfig
|
|
|
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
|
|
|
|
|
|
from .modules.learned_pos_emb import LearnedPositionEmbeddings
|
|
|
|
|
|
from .modules.cond_enc import T3CondEnc, T3Cond
|
|
|
from .modules.t3_config import T3Config
|
|
|
from .llama_configs import LLAMA_CONFIGS
|
|
|
from .inference.t3_hf_backend import T3HuggingfaceBackend
|
|
|
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
|
|
|
|
|
|
|
|
|
# Module-level logger named after the package path (standard `logging` convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class AttrDict(dict):
    """Dict whose entries are also readable/writable as attributes.

    Pointing ``__dict__`` at the instance itself makes attribute access and
    item access share one storage, so ``d.x`` and ``d["x"]`` stay in sync.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self
|
|
|
|
|
|
|
|
|
def _ensure_BOT_EOT(text_tokens: Tensor, hp):
|
|
|
B = text_tokens.size(0)
|
|
|
assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
|
|
|
assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
|
|
|
|
|
|
|
|
|
class T3(nn.Module):
    """
    Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
    * tokenization, including start / stop tokens are always added externally to this class
    * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
    * careful! this class assumes relative positional encoding -- with absolute PE, we would at
      least want to reset the position to 0 when speech tokens begin, and optionally use a
      different PE embedding space for speech.
    """

    def __init__(self, hp=T3Config()):
        # NOTE(review): the default `T3Config()` is evaluated once when the class
        # body is executed and then shared by every `T3()` call — fine if
        # T3Config instances are never mutated; otherwise this is the classic
        # mutable-default hazard. Verify T3Config is treated as read-only.
        super().__init__()
        self.hp = hp
        # Llama backbone built from a named config preset; no pretrained
        # weights are loaded here — weights are expected to come from a checkpoint.
        self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
        self.tfmr = LlamaModel(self.cfg)
        self.dim = self.cfg.hidden_size
        self.deepspeed_patch_applied = False

        # Conditioning encoder plus the two token-embedding tables
        # (text vocabulary and speech-token vocabulary).
        self.cond_enc = T3CondEnc(hp)
        self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
        self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)

        # Optional learned absolute position embeddings for each stream.
        if hp.input_pos_emb == "learned":
            # +2: presumably headroom for the start/stop marker tokens — TODO confirm.
            max_text_seq_len = hp.max_text_tokens + 2
            self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)

            # +2 + 2: presumably markers plus extra margin — TODO confirm.
            max_mel_seq_len = hp.max_speech_tokens + 2 + 2
            self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)

        # Per-stream logit projections from the backbone hidden size.
        self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
        self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
        # Set to True once `inference()` has built the patched HF backend.
        self.compiled = False
|
|
|
|
|
|
@property
|
|
|
def device(self):
|
|
|
return self.speech_head.weight.device
|
|
|
|
|
|
def prepare_conditioning(self, t3_cond: T3Cond):
|
|
|
"""
|
|
|
Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
|
|
|
"""
|
|
|
if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
|
|
|
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
|
|
|
self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
|
|
|
return self.cond_enc(t3_cond)
|
|
|
|
|
|
def prepare_input_embeds(
|
|
|
self,
|
|
|
*,
|
|
|
t3_cond: T3Cond,
|
|
|
text_tokens: torch.LongTensor,
|
|
|
speech_tokens: torch.LongTensor,
|
|
|
):
|
|
|
|
|
|
cond_emb = self.prepare_conditioning(t3_cond)
|
|
|
text_emb = self.text_emb(text_tokens)
|
|
|
text_emb[1].zero_()
|
|
|
|
|
|
speech_emb = self.speech_emb(speech_tokens)
|
|
|
if self.hp.input_pos_emb == "learned":
|
|
|
text_emb = text_emb + self.text_pos_emb(text_tokens)
|
|
|
speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
|
|
|
len_cond = cond_emb.size(1)
|
|
|
|
|
|
if cond_emb.size(0) != text_emb.size(0):
|
|
|
cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
|
|
|
|
|
|
|
|
|
embeds = torch.stack([
|
|
|
torch.cat((ce, te, se))
|
|
|
for ce, te, se in zip(cond_emb, text_emb, speech_emb)
|
|
|
])
|
|
|
return embeds, len_cond
|
|
|
|
|
|
    def forward(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: torch.LongTensor,
        text_token_lens: torch.LongTensor,
        speech_tokens: torch.LongTensor,
        speech_token_lens: torch.LongTensor,
        training=False,
    ):
        """
        Teacher-forced pass over the [cond | text | speech] sequence.

        Args:
            t3_cond: conditioning data (see `T3Cond`).
            text_tokens: (B, len_text) padded text token ids, incl. start/stop markers.
            text_token_lens: (B,) true (unpadded) text lengths.
            speech_tokens: (B, len_speech) padded speech token ids.
            speech_token_lens: (B,) true speech lengths.
            training: disables the transformer KV cache when True.

        Returns:
            AttrDict with text/speech logits, the per-stream latents, and the
            raw last-layer hidden states.
        """
        _ensure_BOT_EOT(text_tokens, self.hp)

        # Build the full [conditioning | text | speech] embedding sequence.
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=speech_tokens,
        )

        # Run the Llama backbone on embeddings directly (no input ids).
        tfmr_out = self.tfmr.forward(
            input_ids=None,
            inputs_embeds=embeds,
            output_hidden_states=True,
            return_dict=True,
            use_cache=(not training),  # KV cache is pointless during training
        )
        # Last transformer layer's hidden states: (B, len_cond+len_text+len_speech, dim).
        hidden_states = tfmr_out.hidden_states[-1]

        # Scatter each sample's text / speech latents into zero-padded buffers,
        # copying only up to that sample's true length; padded tail stays zero.
        len_text = text_tokens.size(1)
        len_speech = speech_tokens.size(1)
        B, _, dim = hidden_states.shape
        device, dtype = hidden_states.device, hidden_states.dtype
        text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
        speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
        ttl, stl = text_token_lens, speech_token_lens
        for i in range(B):
            # Text occupies [len_cond, len_cond + ttl[i]) in the packed sequence.
            text_end = len_cond + ttl[i].item()
            # Speech always starts after the full padded text span (not ttl[i]),
            # matching how prepare_input_embeds concatenated the streams.
            speech_start = len_cond + text_tokens.size(1)
            speech_end = speech_start + stl[i].item()
            text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
            speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]

        # Project latents to per-stream vocabulary logits: (B, seq, vocab).
        text_logits = self.text_head(text_latents)
        speech_logits = self.speech_head(speech_latents)

        return AttrDict(
            text_logits=text_logits,
            text_latents=text_latents,
            speech_logits=speech_logits,
            speech_latents=speech_latents,
            hidden_states=hidden_states,
        )
|
|
|
|
|
|
def loss(
|
|
|
self,
|
|
|
*,
|
|
|
t3_cond: T3Cond,
|
|
|
text_tokens: torch.LongTensor,
|
|
|
text_token_lens: torch.LongTensor,
|
|
|
speech_tokens: torch.LongTensor,
|
|
|
speech_token_lens: torch.LongTensor,
|
|
|
):
|
|
|
"training method"
|
|
|
len_text = text_tokens.size(1)
|
|
|
len_speech = speech_tokens.size(1)
|
|
|
assert len_text == text_token_lens.max()
|
|
|
assert len_speech == speech_token_lens.max()
|
|
|
|
|
|
out = self.forward(
|
|
|
t3_cond=t3_cond,
|
|
|
text_tokens=text_tokens,
|
|
|
text_token_lens=text_token_lens,
|
|
|
speech_tokens=speech_tokens,
|
|
|
speech_token_lens=speech_token_lens,
|
|
|
training=True,
|
|
|
)
|
|
|
|
|
|
|
|
|
IGNORE_ID = -100
|
|
|
device = out.text_logits.device
|
|
|
mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]
|
|
|
mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]
|
|
|
masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
|
|
|
masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
|
|
|
loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
|
|
|
loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID)
|
|
|
|
|
|
return loss_text, loss_speech
|
|
|
|
|
|
    @torch.inference_mode()
    def inference(
        self,
        *,
        t3_cond: T3Cond,
        text_tokens: Tensor,
        initial_speech_tokens: Optional[Tensor]=None,
        # misc conditioning (not yet supported)
        prepend_prompt_speech_tokens: Optional[Tensor]=None,
        # HF-generate-style sampling args
        num_return_sequences=1,
        max_new_tokens=None,
        stop_on_eos=True,
        do_sample=True,
        temperature=0.8,
        length_penalty=1.0,
        repetition_penalty=1.2,
        min_p=0.05,
        top_p=1.0,
        cfg_weight=0,
    ):
        """
        Autoregressively sample speech tokens conditioned on `t3_cond` and
        `text_tokens`, with classifier-free guidance, repetition penalty and
        min-p / top-p filtering. Returns the sampled speech token ids.

        Args:
            text_tokens: a 1D (unbatched) or 2D (batched) tensor.

        NOTE(review): `num_return_sequences`, `stop_on_eos`, `do_sample` and
        `length_penalty` are accepted but never read below — confirm whether
        they are intended for a future HF-generate path.
        NOTE(review): the default `max_new_tokens=None` reaches
        `range(max_new_tokens)` and raises TypeError — callers apparently must
        always pass an int; confirm against call sites.
        """
        assert prepend_prompt_speech_tokens is None, "not implemented"
        _ensure_BOT_EOT(text_tokens, self.hp)
        text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)

        # Default initial speech context: a single start-of-speech token per row.
        if initial_speech_tokens is None:
            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])

        # Pack [conditioning | text | initial speech] into one embedding sequence.
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=initial_speech_tokens,
        )

        # NOTE(review): forcing `compiled = False` here makes the branch below
        # run on every call, rebuilding the analyzer and patched backend each
        # time — the caching flag is effectively dead. Confirm if intentional
        # (e.g. because text_tokens_slice changes per call).
        self.compiled = False

        if not self.compiled:
            # Watches cross-attention-style alignment between text and speech
            # positions to decide EOS; operates on the packed sequence layout.
            alignment_stream_analyzer = AlignmentStreamAnalyzer(
                self.tfmr,
                None,
                text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
                alignment_layer_idx=9,  # NOTE(review): magic layer index — confirm source
                eos_idx=self.hp.stop_speech_token,
            )
            # Wrap the Llama backbone so HF-style step calls use our speech
            # embedding table and speech logit head.
            patched_model = T3HuggingfaceBackend(
                config=self.cfg,
                llama=self.tfmr,
                speech_enc=self.speech_emb,
                speech_head=self.speech_head,
                alignment_stream_analyzer=alignment_stream_analyzer,
            )
            self.patched_model = patched_model
            self.compiled = True

        device = embeds.device

        # Start the generated stream with an explicit BOS speech token at
        # learned position 0.
        bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
        bos_embed = self.speech_emb(bos_token)
        bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)

        # Duplicate for the 2-row CFG batch (conditional + unconditional).
        # NOTE(review): assumes `embeds` also has batch size 2, i.e. the caller
        # duplicated `text_tokens` for CFG — confirm at the call site.
        bos_embed = torch.cat([bos_embed, bos_embed])

        inputs_embeds = torch.cat([embeds, bos_embed], dim=1)

        # Running token history (for repetition penalty) and sampled outputs.
        generated_ids = bos_token.clone()
        predicted = []

        # Logit processors, built once outside the sampling loop.
        min_p_warper = MinPLogitsWarper(min_p=min_p)
        top_p_warper = TopPLogitsWarper(top_p=top_p)
        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))

        # Prefill pass: run the whole prompt once to populate the KV cache.
        output = self.patched_model(
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=True,
            output_attentions=True,   # needed by the alignment analyzer
            output_hidden_states=True,
            return_dict=True,
        )
        past = output.past_key_values

        for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
            logits = output.logits[:, -1, :]

            # Classifier-free guidance: row 0 = conditional, row 1 = unconditional.
            # With cfg_weight == 0 this reduces to the conditional logits, though
            # both rows are still computed every step.
            logits_cond = logits[0:1]
            logits_uncond = logits[1:2]
            logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
            logits = logits.squeeze(1)

            if temperature != 1.0:
                logits = logits / temperature

            # Penalize already-generated tokens, then filter the distribution.
            logits = repetition_penalty_processor(generated_ids, logits)
            logits = min_p_warper(None, logits)
            logits = top_p_warper(None, logits)

            # Sample one token from the filtered distribution.
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

            predicted.append(next_token)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Stop immediately on end-of-speech (EOS token is kept in the output).
            if next_token.view(-1) == self.hp.stop_speech_token:
                break

            # Embed the new token at learned position i + 1 (position 0 was BOS).
            next_token_embed = self.speech_emb(next_token)
            next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)

            # Duplicate again for the 2-row CFG batch.
            next_token_embed = torch.cat([next_token_embed, next_token_embed])

            # Single-step decode reusing the KV cache.
            output = self.patched_model(
                inputs_embeds=next_token_embed,
                past_key_values=past,
                output_attentions=True,
                output_hidden_states=True,
                return_dict=True,
            )
            past = output.past_key_values

        # (1, n_generated) — raises if the loop ran zero iterations.
        predicted_tokens = torch.cat(predicted, dim=1)
        return predicted_tokens
|
|
|
|