import torch
import torch.nn as nn
import torchaudio
from dataclasses import asdict

from config import ModelConfig, MelConfig
from models.model import StableTTS
from text import symbols
from text import cleaned_text_to_sequence
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
from datas.dataset import intersperse
from utils.audio import LogMelSpectrogram, load_and_resample_audio

def get_vocoder(model_path: str, model_name: str = 'ffgan') -> nn.Module:
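    # Load a pretrained vocoder (mel spectrogram -> waveform).
    # 'ffgan' wraps a pretrained FireflyGAN checkpoint; 'vocos' builds a Vocos
    # model from the default configs and loads its state dict from model_path.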
    if model_name == 'ffgan':
        # training or changing the ffgan config is not supported in this repo;
        # you can train your own model at https://github.com/fishaudio/vocoder
        from vocoders.ffgan.model import FireflyGANBaseWrapper
        vocoder = FireflyGANBaseWrapper(model_path)
    elif model_name == 'vocos':
        from vocoders.vocos.models.model import Vocos
        from config import VocosConfig
        vocoder = Vocos(VocosConfig(), MelConfig())
        vocoder.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu'))
        vocoder.eval()
    else:
        raise NotImplementedError(f"Unsupported model: {model_name}")
    return vocoder


class StableTTSAPI(nn.Module):
    def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'):
        super().__init__()

        self.mel_config = MelConfig()
        self.tts_model_config = ModelConfig()

        self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config))

        # text to mel spectrogram
        self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config))
        self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True))
        self.tts_model.eval()

        # mel spectrogram to waveform
        self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name)
        self.vocoder_model.eval()

        self.g2p_mapping = {
            'chinese': chinese_to_cnm3,
            'japanese': japanese_to_ipa2,
            'english': english_to_ipa2,
        }
        self.supported_languages = self.g2p_mapping.keys()

    @torch.inference_mode()
    def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0):
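        # Pipeline: text -> phonemes (g2p) -> token ids with blanks interspersed
        # -> mel spectrogram from the TTS model -> waveform from the vocoder.
        # `step` is the number of sampling steps (and `solver` the ODE solver)
        # used by the decoder, `cfg` the classifier-free guidance scale, and
        # `ref_audio` a path to a reference wav providing the target speaker.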
        device = next(self.parameters()).device
        phonemizer = self.g2p_mapping.get(language)
        if phonemizer is None:
            raise ValueError(f"Unsupported language: {language}. Supported languages: {list(self.supported_languages)}")
        text = phonemizer(text)
        text = torch.tensor(intersperse(cleaned_text_to_sequence(text), item=0), dtype=torch.long, device=device).unsqueeze(0)
        text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device)

        ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device)
        ref_audio = self.mel_extractor(ref_audio)

        mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs']
        audio_output = self.vocoder_model(mel_output)
        return audio_output.cpu(), mel_output.cpu()

    def get_params(self):
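        # Parameter counts of the TTS model and the vocoder, in millions.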
        tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6
        vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6
        return tts_param, vocoder_param


if __name__ == '__main__':
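    # Minimal smoke test: synthesise a Chinese sentence with the Vocos vocoder
    # and save the resulting waveform.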
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tts_model_path = './checkpoints/checkpoint_0.pt'
    vocoder_model_path = './vocoders/pretrained/vocos.pt'

    model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos')
    model.to(device)

    text = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……'
    audio = './audio_1.wav'

    audio_output, mel_output = model.inference(text, audio, 'chinese', 10, solver='dopri5', cfg=3)
    print(audio_output.shape)
    print(mel_output.shape)

    torchaudio.save('output.wav', audio_output, MelConfig().sample_rate)