Skip to content

Instantly share code, notes, and snippets.

@vpnry
Created May 25, 2023 01:42
Show Gist options
  • Select an option

  • Save vpnry/0ea73a9eb168953d464ce14dc95aa713 to your computer and use it in GitHub Desktop.

Select an option

Save vpnry/0ea73a9eb168953d464ce14dc95aa713 to your computer and use it in GitHub Desktop.
MMS speech fairseq TTS app
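'''
Synthesize speech from a text file with an MMS (VITS) TTS model.

Run with
PYTHONPATH=$PYTHONPATH:./fairseq/vits python3 app.py --model_dir LANG --file_path text.txt

LANG is the name of the model directory in the current working directory
(e.g. vie); it must contain vocab.txt, config.json and G_100000.pth.
text.txt is a UTF-8 text file; each non-empty line is split into sentences.
Per-sentence WAVs are written to ./output_LANG/ and joined with sox into
./LANG_combined.wav.
'''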
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from underthesea import text_normalize as vie_text_normalize
from underthesea import sent_tokenize as vie_sent_tokenize # Vietnamese NLP toolkit
from nltk import sent_tokenize as nltk_sent_tokenize
import os
import re
import glob
import json
import tempfile
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import numpy as np
import commons
import utils
import argparse
import subprocess
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from scipy.io.wavfile import write
import nltk
def _ensure_nltk_punkt():
    """Make sure the Punkt data used by ``nltk.sent_tokenize`` is installed.

    Older NLTK releases load the pickled ``punkt`` package, while newer ones
    load ``punkt_tab``; fetching both keeps the script working across versions.
    Each package is downloaded only if ``nltk.data.find`` cannot locate it, so
    repeated runs do not hit the network.
    """
    for package in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{package}")
        except LookupError:
            nltk.download(package)


_ensure_nltk_punkt()
class TextMapper(object):
    """Map text to the integer symbol IDs expected by the TTS model.

    The vocabulary file lists one symbol per line; a symbol's ID is its
    zero-based line number in that file.
    """

    def __init__(self, vocab_file):
        """Load the symbol table from *vocab_file*.

        Args:
            vocab_file: path to a UTF-8 text file with one symbol per line.

        Raises:
            ValueError: if the vocabulary has no space (" ") symbol.
        """
        # Decode explicitly as UTF-8 (vocabularies may contain non-ASCII
        # symbols, and the platform default encoding is not always UTF-8) and
        # close the file as soon as it has been read.
        with open(vocab_file, encoding="utf-8") as f:
            self.symbols = [line.replace("\n", "") for line in f]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}

    def text_to_sequence(self, text, cleaner_names):
        '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
        Args:
            text: string to convert to a sequence; leading/trailing
                whitespace is stripped first
            cleaner_names: accepted for API compatibility but not used --
                no cleaners are applied here
        Returns:
            List of integers corresponding to the symbols in the text
        Raises:
            KeyError: if a character is not in the vocabulary (call
                ``filter_oov`` first to drop such characters)
        '''
        clean_text = text.strip()
        return [self._symbol_to_id[symbol] for symbol in clean_text]

    def get_text(self, text, hps):
        """Return *text* as a ``torch.LongTensor`` of symbol IDs.

        When ``hps.data.add_blank`` is true, ID 0 is interspersed between
        symbols via ``commons.intersperse``.
        """
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def filter_oov(self, text):
        """Drop every character of *text* that is not in the vocabulary."""
        val_chars = self._symbol_to_id
        txt_filt = "".join(ch for ch in text if ch in val_chars)
        print(f"text after filtering OOV: {txt_filt}")
        return txt_filt

    def uromanize(self, text, uroman_pl):
        """Romanize *text* with the uroman Perl script at *uroman_pl*.

        Requires ``perl`` on PATH. Runs of whitespace in the output are
        collapsed to single spaces, and only the first output line is
        returned, so *text* should be a single line (e.g. one sentence).

        Returns:
            The romanized first line, or "" if uroman printed nothing.

        Raises:
            subprocess.CalledProcessError: if the script exits non-zero.
        """
        # "-l xxx" passes a placeholder language code (presumably meaning no
        # language-specific rules -- NOTE(review): confirm against uroman docs).
        # An argument list (no shell) keeps paths with spaces safe.
        result = subprocess.run(
            ["perl", uroman_pl, "-l", "xxx"],
            input=text,
            capture_output=True,
            encoding="utf-8",
            check=True,
        )
        lines = [re.sub(r"\s+", " ", line).strip()
                 for line in result.stdout.splitlines()]
        return lines[0] if lines else ""
def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is missing on older torch builds, so look it up softly.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def _resolve_uroman_dir(uroman_dir):
    """Return *uroman_dir*, or clone uroman into a new temp dir if it is None."""
    if uroman_dir is not None:
        return uroman_dir
    # mkdtemp rather than TemporaryDirectory: the checkout must outlive this
    # call because tts() runs uroman.pl later. The directory is left on disk.
    target = tempfile.mkdtemp(prefix="uroman_")
    # HTTPS works anonymously; an SSH URL would require a configured key.
    cmd = ["git", "clone", "https://github.com/isi-nlp/uroman.git", target]
    print(" ".join(cmd))
    subprocess.run(cmd, check=True)
    return target


def load_model(model_dir, uroman_dir=None):
    """Build the VITS synthesizer for *model_dir* and load its weights.

    *model_dir* must contain ``vocab.txt``, ``config.json`` and the generator
    checkpoint ``G_100000.pth``.

    Args:
        model_dir: directory holding the model files.
        uroman_dir: local checkout of the uroman repository. Only used when the
            model expects romanized input; if None, uroman is cloned from
            GitHub into a new temporary directory.

    Returns:
        Tuple ``(net_g, text_mapper, hps, torch_device, uroman_pl)``, where
        ``uroman_pl`` is the path to ``bin/uroman.pl`` or None when the model
        does not need romanized input.

    Raises:
        FileNotFoundError: if ``config.json`` is missing.
        subprocess.CalledProcessError: if cloning uroman fails.
    """
    vocab_file = f"{model_dir}/vocab.txt"
    config_file = f"{model_dir}/config.json"
    # An explicit raise (not assert) so the check survives `python -O`.
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"{config_file} doesn't exist")
    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)
    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)
    torch_device = _select_device()
    net_g.to(torch_device)
    net_g.eval()
    g_pth = f"{model_dir}/G_100000.pth"
    print(f"Loading model: {g_pth}")
    utils.load_checkpoint(g_pth, net_g, None)

    uroman_pl = None
    # A training-file list ending in ".uroman" is treated as the marker of a
    # model that expects uroman-romanized input.
    if hps.data.training_files.split('.')[-1] == 'uroman':
        print("Loading uroman")
        uroman_pl = os.path.join(_resolve_uroman_dir(uroman_dir), "bin", "uroman.pl")
    return net_g, text_mapper, hps, torch_device, uroman_pl
def tts(model, wav_path, text, *, noise_scale=0.667, noise_scale_w=0.8,
        length_scale=1.0):
    """Synthesize *text* and save it to *wav_path* as a WAV file.

    Args:
        model: the tuple returned by ``load_model``:
            ``(net_g, text_mapper, hps, torch_device, uroman_pl)``.
        wav_path: output file path; missing parent directories are created.
        text: the text to speak (one sentence).
        noise_scale, noise_scale_w, length_scale: sampling controls forwarded
            unchanged to ``net_g.infer``; the defaults are the values this
            function always used before they became parameters.
    """
    net_g, text_mapper, hps, torch_device, uroman_pl = model
    txt = text
    if uroman_pl:
        print("uromanize")
        txt = text_mapper.uromanize(txt, uroman_pl)
        print(f"uroman text: {txt}")
    # Lower-case the (possibly romanized) text, then drop characters the
    # vocabulary does not know, so every remaining symbol maps to an ID.
    txt = txt.lower()
    txt = text_mapper.filter_oov(txt)
    stn_tst = text_mapper.get_text(txt, hps)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(torch_device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(torch_device)
        hyp = net_g.infer(
            x_tst, x_tst_lengths, noise_scale=noise_scale,
            noise_scale_w=noise_scale_w, length_scale=length_scale,
        )[0][0, 0].cpu().float().numpy()
    out_dir = os.path.dirname(wav_path)
    if out_dir:  # os.makedirs("") raises, e.g. for a bare "out.wav"
        os.makedirs(out_dir, exist_ok=True)
    print(f"Saving WAV: {wav_path}")
    write(wav_path, hps.data.sampling_rate, hyp)
def read_and_normalize_file(model_lang_dir, filename):
    """Read *filename* (UTF-8) and return its sentences as one flat list.

    Every non-blank line is treated as a paragraph and split into sentences.
    When *model_lang_dir* is "vie" (case-insensitive), underthesea splits the
    paragraph and each sentence is passed through underthesea's
    ``text_normalize``. Otherwise NLTK's ``sent_tokenize`` splits it and the
    sentences are kept as-is. Whitespace-only sentences are dropped either way.
    """
    with open(filename, 'r', encoding='utf-8') as handle:
        content = handle.read()

    non_blank_lines = [line for line in content.split("\n") if line.strip()]
    use_vietnamese = model_lang_dir.lower() == "vie"

    result = []
    for line in non_blank_lines:
        if use_vietnamese:
            for piece in vie_sent_tokenize(line):
                if piece.strip():
                    result.append(vie_text_normalize(piece))
        else:
            result.extend(piece for piece in nltk_sent_tokenize(line) if piece.strip())
    return result
def combine_wav(model_lang_dir, output_dir, sentences):
    """Concatenate the per-sentence WAVs into ``./<model_lang_dir>_combined.wav``.

    Expects ``<output_dir>/sentence_<i>.wav`` for every index ``i`` of
    *sentences* and joins them in index order with the external ``sox`` tool.

    Returns:
        The combined file path, or None when *sentences* is empty.

    Raises:
        FileNotFoundError: if the ``sox`` executable cannot be found.
        subprocess.CalledProcessError: if ``sox`` exits with an error.
    """
    combined_output_file = f"./{model_lang_dir}_combined.wav"
    if not sentences:
        print("\n\n--- No sentences to combine; nothing written.")
        return None
    inputs = [f"{output_dir}/sentence_{i}.wav" for i in range(len(sentences))]
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact; check=True surfaces sox failures instead of
    # reporting success regardless.
    subprocess.run(["sox", *inputs, combined_output_file], check=True)
    print("\n\n--- Done! Final one WAV file:", combined_output_file)
    return combined_output_file
if __name__ == '__main__':
    # Command-line entry point: split the input file into sentences,
    # synthesize one WAV per sentence, then join them with sox.
    parser = argparse.ArgumentParser(description='TTS inference')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='model directory name under the current directory (e.g. vie)')
    parser.add_argument('--file_path', type=str, required=True,
                        help='input file path')
    parser.add_argument('--uroman-dir', type=str,
                        help='uroman lib dir (will download if not specified)')
    args = parser.parse_args()
    model_lang_dir, file_path, uroman_dir = args.model_dir, args.file_path, args.uroman_dir
    output_dir = f"./output_{model_lang_dir}"
    os.makedirs(output_dir, exist_ok=True)
    sentences = read_and_normalize_file(model_lang_dir, file_path)
    print(f"Total sentences: {len(sentences)}")
    if not sentences:
        # Nothing to synthesize: skip the costly model load and the sox step.
        raise SystemExit(f"No text found in {file_path}; nothing to synthesize.")
    model_dir = f"./{model_lang_dir}"
    model = load_model(model_dir, uroman_dir)
    for i, sentence in enumerate(sentences):
        output_wav_path = f"{output_dir}/sentence_{i}.wav"
        tts(model, output_wav_path, sentence)
    combine_wav(model_lang_dir, output_dir, sentences)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment