Created
May 25, 2023 01:42
-
-
Save vpnry/0ea73a9eb168953d464ce14dc95aa713 to your computer and use it in GitHub Desktop.
MMS speech fairseq TTS app
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ''' | |
| Run with | |
| PYTHONPATH=$PYTHONPATH:./fairseq/vits python3 app.py --model_dir LANG --file_path text.txt | |
| LANG is the model's language code (e.g. vie); the model is loaded from ./LANG | |
| ''' | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from underthesea import text_normalize as vie_text_normalize | |
| from underthesea import sent_tokenize as vie_sent_tokenize # Vietnamese NLP toolkit | |
| from nltk import sent_tokenize as nltk_sent_tokenize | |
| import os | |
| import re | |
| import glob | |
| import json | |
| import tempfile | |
| import math | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| import commons | |
| import utils | |
| import argparse | |
| import subprocess | |
| from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate | |
| from models import SynthesizerTrn | |
| from scipy.io.wavfile import write | |
| import nltk | |
# NLTK >= 3.8.2 loads sentence-tokenizer data from the "punkt_tab" package
# instead of the pickled "punkt" package; fetch both so nltk_sent_tokenize
# works with both older and newer NLTK releases.
for _nltk_package in ("punkt", "punkt_tab"):
    nltk.download(_nltk_package)
class TextMapper(object):
    """Convert text to the integer symbol IDs of a TTS model vocabulary.

    The vocabulary file lists one symbol per line; a symbol's ID is its
    0-based line index. The file must contain a line holding a single space.
    """

    def __init__(self, vocab_file):
        """Load the symbol inventory from *vocab_file* (UTF-8, one symbol per line).

        Raises:
            ValueError: if the vocabulary has no space symbol.
        """
        # Read as UTF-8 explicitly: symbols may be non-ASCII (e.g. Vietnamese)
        # and the platform default encoding could mis-decode them. The `with`
        # block closes the file deterministically.
        with open(vocab_file, encoding="utf-8") as vocab:
            self.symbols = [line.replace("\n", "") for line in vocab]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = dict(enumerate(self.symbols))

    def text_to_sequence(self, text, cleaner_names):
        """Convert *text* to a list of symbol IDs, one per character.

        Args:
            text: string to convert; leading/trailing whitespace is stripped.
            cleaner_names: accepted for interface compatibility; no cleaners
                are applied.
        Returns:
            List of integer symbol IDs.
        Raises:
            KeyError: if a character is not in the vocabulary (call
                ``filter_oov`` first to drop such characters).
        """
        return [self._symbol_to_id[symbol] for symbol in text.strip()]

    def get_text(self, text, hps):
        """Return *text* as a 1-D ``torch.LongTensor`` of symbol IDs.

        When ``hps.data.add_blank`` is truthy, ID 0 is interspersed into the
        sequence via ``commons.intersperse``.
        """
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def filter_oov(self, text):
        """Return *text* with every character missing from the vocabulary removed."""
        txt_filt = "".join(ch for ch in text if ch in self._symbol_to_id)
        print(f"text after filtering OOV: {txt_filt}")
        return txt_filt

    def uromanize(self, text, uroman_pl, iso="xxx"):
        """Romanize *text* with the uroman Perl script at *uroman_pl*.

        Args:
            text: text to romanize.
            uroman_pl: path to ``uroman.pl``.
            iso: language code passed to uroman's ``-l`` option. The default
                "xxx" presumably selects uroman's language-independent rules;
                confirm against the uroman documentation.
        Returns:
            The romanized text with runs of whitespace collapsed to one space.
        Raises:
            subprocess.CalledProcessError: if the uroman script fails.
            FileNotFoundError: if ``perl`` is not installed.
        """
        # The text goes through stdin and the command is an argument list, so
        # no shell is involved and arbitrary input text cannot be misinterpreted.
        result = subprocess.run(
            ["perl", uroman_pl, "-l", iso],
            input=text,
            capture_output=True,
            encoding="utf-8",
            check=True,
        )
        return re.sub(r"\s+", " ", result.stdout).strip()
def load_model(model_dir, uroman_dir=None):
    """Load a VITS generator and the helpers needed to run inference with it.

    Args:
        model_dir: directory containing ``vocab.txt``, ``config.json`` and the
            ``G_100000.pth`` generator checkpoint.
        uroman_dir: local checkout of the uroman repository. Only used when
            the model's training files end in ``.uroman``. If None, uroman is
            cloned into a temporary directory that is removed at interpreter
            exit.

    Returns:
        Tuple ``(net_g, text_mapper, hps, torch_device, uroman_pl)``, where
        ``uroman_pl`` is the path to ``uroman.pl``, or None when the model does
        not use uromanized text.

    Raises:
        FileNotFoundError: if ``config.json`` or ``uroman.pl`` is missing.
        subprocess.CalledProcessError: if cloning uroman fails.
    """
    import atexit
    import shutil

    vocab_file = os.path.join(model_dir, "vocab.txt")
    config_file = os.path.join(model_dir, "config.json")
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"{config_file} doesn't exist")
    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)
    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)

    # Older torch builds have no `torch.backends.mps`; guard the lookup.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        torch_device = torch.device("cuda")
    elif (mps_backend is not None and mps_backend.is_available()
          and mps_backend.is_built()):
        torch_device = torch.device("mps")
    else:
        torch_device = torch.device("cpu")
    net_g.to(torch_device)
    net_g.eval()

    g_pth = os.path.join(model_dir, "G_100000.pth")
    print(f"Loading model: {g_pth}")
    utils.load_checkpoint(g_pth, net_g, None)

    uroman_pl = None
    is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
    if is_uroman:
        print("Loading uroman")
        if uroman_dir is None:
            # The checkout must outlive this function: tts() runs uroman.pl
            # later. mkdtemp keeps the directory until the atexit cleanup.
            uroman_dir = tempfile.mkdtemp(prefix="uroman_")
            atexit.register(shutil.rmtree, uroman_dir, ignore_errors=True)
            cmd = ["git", "clone", "https://github.com/isi-nlp/uroman.git", uroman_dir]
            print(" ".join(cmd))
            subprocess.check_call(cmd)
        # NOTE(review): assumes the repository layout bin/uroman.pl; newer
        # uroman releases may place the script elsewhere. Confirm.
        uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
        if not os.path.isfile(uroman_pl):
            raise FileNotFoundError(f"{uroman_pl} doesn't exist")
    return net_g, text_mapper, hps, torch_device, uroman_pl
def tts(model, wav_path, text):
    """Synthesize *text* and write it as a WAV file at *wav_path*.

    Args:
        model: tuple ``(net_g, text_mapper, hps, torch_device, uroman_pl)`` as
            returned by ``load_model``.
        wav_path: destination WAV path. Its parent directory is created if
            needed.
        text: text to speak.

    If no in-vocabulary characters remain after filtering, a short silence is
    written instead, so every expected ``sentence_{i}.wav`` file still exists.
    """
    silence_seconds = 0.25
    net_g, text_mapper, hps, torch_device, uroman_pl = model

    txt = text
    if uroman_pl:
        print("uromanize")
        # Romanize first so the lowercasing and vocabulary filter below
        # operate on the romanized characters.
        txt = text_mapper.uromanize(txt, uroman_pl)
        print(f"uroman text: {txt}")
    txt = text_mapper.filter_oov(txt.lower())

    out_dir = os.path.dirname(wav_path)
    if out_dir:  # os.makedirs("") raises, so skip for bare filenames
        os.makedirs(out_dir, exist_ok=True)

    if not txt:
        print(f"No in-vocabulary characters in {text!r}; writing silence")
        hyp = np.zeros(int(hps.data.sampling_rate * silence_seconds), dtype=np.float32)
    else:
        stn_tst = text_mapper.get_text(txt, hps)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(torch_device)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(torch_device)
            hyp = net_g.infer(
                x_tst, x_tst_lengths, noise_scale=.667,
                noise_scale_w=0.8, length_scale=1.0
            )[0][0, 0].cpu().float().numpy()

    print(f"Saving WAV: {wav_path}")
    write(wav_path, hps.data.sampling_rate, hyp)
def read_and_normalize_file(model_lang_dir, filename):
    """Split a UTF-8 text file into the list of sentences to synthesize.

    Each non-blank line is treated as a paragraph and sentence-tokenized.
    For Vietnamese (``model_lang_dir`` equal to "vie", case-insensitive),
    underthesea's tokenizer is used and each sentence is normalized with
    ``vie_text_normalize``. Every other language uses NLTK's tokenizer.
    Whitespace-only sentences are dropped.
    """
    with open(filename, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()

    is_vietnamese = model_lang_dir.lower() == "vie"
    result = []
    for line in raw_text.split("\n"):
        if not line.strip():
            continue  # blank lines carry no text
        if is_vietnamese:
            pieces = vie_sent_tokenize(line)
        else:
            pieces = nltk_sent_tokenize(line)
        for piece in pieces:
            if not piece.strip():
                continue
            result.append(vie_text_normalize(piece) if is_vietnamese else piece)
    return result
def combine_wav(model_lang_dir, output_dir, sentences):
    """Concatenate the per-sentence WAVs into ``./{model_lang_dir}_combined.wav`` with sox.

    Args:
        model_lang_dir: language code used to name the combined file.
        output_dir: directory holding the ``sentence_{i}.wav`` files.
        sentences: the synthesized sentences. Only its length is used, to
            know how many ``sentence_{i}.wav`` files to join (in index order).

    Returns:
        Path of the combined WAV file, or None when there is nothing to combine.

    Raises:
        FileNotFoundError: if the ``sox`` executable is not installed.
        subprocess.CalledProcessError: if sox exits with an error.
    """
    if not sentences:
        print("\n\n--- No sentences to combine; no WAV file written.")
        return None
    combined_output_file = f"./{model_lang_dir}_combined.wav"
    inputs = [f"{output_dir}/sentence_{i}.wav" for i in range(len(sentences))]
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact. check=True reports sox failures instead of
    # printing "Done!" for a file that was never written.
    subprocess.run(["sox", *inputs, combined_output_file], check=True)
    print("\n\n--- Done! Final one WAV file:", combined_output_file)
    return combined_output_file
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='TTS inference')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='model language code; the model is loaded from ./MODEL_DIR')
    parser.add_argument('--file_path', type=str, required=True,
                        help='input file path')
    parser.add_argument('--uroman-dir', type=str,
                        help='uroman lib dir (will download if not specified)')
    # Keep `args` as a module-level name so any function that reads the
    # global still works.
    args = parser.parse_args()
    model_lang_dir, file_path, uroman_dir = args.model_dir, args.file_path, args.uroman_dir
    if not os.path.isfile(file_path):
        parser.error(f"input file not found: {file_path}")

    output_dir = f"./output_{model_lang_dir}"
    os.makedirs(output_dir, exist_ok=True)

    sentences = read_and_normalize_file(model_lang_dir, file_path)
    print("Total sentences: {}".format(len(sentences)))

    model_dir = f"./{model_lang_dir}"
    model = load_model(model_dir, uroman_dir)
    for i, sentence in enumerate(sentences):
        output_wav_path = f"{output_dir}/sentence_{i}.wav"
        tts(model, output_wav_path, sentence)
    combine_wav(model_lang_dir, output_dir, sentences)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment