diff --git a/TTS/VERSION b/TTS/VERSION new file mode 100644 index 0000000..2157409 --- /dev/null +++ b/TTS/VERSION @@ -0,0 +1 @@ +0.22.0 diff --git a/TTS/api.py b/TTS/api.py new file mode 100644 index 0000000..7abc188 --- /dev/null +++ b/TTS/api.py @@ -0,0 +1,458 @@ +import tempfile +import warnings +from pathlib import Path +from typing import Union + +import numpy as np +from torch import nn + +from TTS.utils.audio.numpy_transforms import save_wav +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer +from TTS.config import load_config + + +class TTS(nn.Module): + """TODO: Add voice conversion and Capacitron support.""" + + def __init__( + self, + model_name: str = "", + model_path: str = None, + config_path: str = None, + vocoder_path: str = None, + vocoder_config_path: str = None, + progress_bar: bool = True, + gpu=False, + ): + """🐸TTS python interface that allows to load and use the released models. + + Example with a multi-speaker model: + >>> from TTS.api import TTS + >>> tts = TTS(TTS.list_models()[0]) + >>> wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0]) + >>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav") + + Example with a single-speaker model: + >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False) + >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") + + Example loading a model from a path: + >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False) + >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") + + Example voice cloning with YourTTS in English, French and Portuguese: + >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav") + >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav") + >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav") + + Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html): + >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is a test.", file_path="output.wav") + + Args: + model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. + model_path (str, optional): Path to the model checkpoint. Defaults to None. + config_path (str, optional): Path to the model config. Defaults to None. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. + progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + super().__init__() + self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) + self.config = load_config(config_path) if config_path else None + self.synthesizer = None + self.voice_converter = None + self.model_name = "" + if gpu: + warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") + + if model_name is not None and len(model_name) > 0: + if "tts_models" in model_name: + self.load_tts_model_by_name(model_name, gpu) + elif "voice_conversion_models" in model_name: + self.load_vc_model_by_name(model_name, gpu) + else: + self.load_model_by_name(model_name, gpu) + + if model_path: + self.load_tts_model_by_path( + model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu + ) + + @property + def models(self): + return self.manager.list_tts_models() + + @property + def is_multi_speaker(self): + if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager: + return self.synthesizer.tts_model.speaker_manager.num_speakers > 1 + return False + + @property + def is_multi_lingual(self): + # Not sure what sets this to None, but applied a fix to prevent crashing. + if ( + isinstance(self.model_name, str) + and "xtts" in self.model_name + or self.config + and ("xtts" in self.config.model or len(self.config.languages) > 1) + ): + return True + if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: + return self.synthesizer.tts_model.language_manager.num_languages > 1 + return False + + @property + def speakers(self): + if not self.is_multi_speaker: + return None + return self.synthesizer.tts_model.speaker_manager.speaker_names + + @property + def languages(self): + if not self.is_multi_lingual: + return None + return self.synthesizer.tts_model.language_manager.language_names + + @staticmethod + def get_models_file_path(): + return Path(__file__).parent / ".models.json" + + def list_models(self): + return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False) + + def download_model_by_name(self, model_name: str): + model_path, config_path, model_item = self.manager.download_model(model_name) + if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): + # return model directory if there are multiple files + # we assume that the model knows how to load itself + return None, None, None, None, model_path + if model_item.get("default_vocoder") is None: + return model_path, config_path, None, None, None + vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) + return model_path, config_path, vocoder_path, vocoder_config_path, None + + def load_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + self.load_tts_model_by_name(model_name, gpu) + + def load_vc_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the voice conversion models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + self.model_name = model_name + model_path, config_path, _, _, _ = self.download_model_by_name(model_name) + self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) + + def load_tts_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + + TODO: Add tests + """ + self.synthesizer = None + self.model_name = model_name + + model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name( + model_name + ) + + # init synthesizer + # None values are fetch from the model + self.synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=None, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint=None, + encoder_config=None, + model_dir=model_dir, + use_cuda=gpu, + ) + + def load_tts_model_by_path( + self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False + ): + """Load a model from a path. + + Args: + model_path (str): Path to the model checkpoint. + config_path (str): Path to the model config. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config (str, optional): Path to the vocoder config. Defaults to None. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + + self.synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=None, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config, + encoder_checkpoint=None, + encoder_config=None, + use_cuda=gpu, + ) + + def _check_arguments( + self, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = None, + **kwargs, + ) -> None: + """Check if the arguments are valid for the model.""" + # check for the coqui tts models + if self.is_multi_speaker and (speaker is None and speaker_wav is None): + raise ValueError("Model is multi-speaker but no `speaker` is provided.") + if self.is_multi_lingual and language is None: + raise ValueError("Model is multi-lingual but no `language` is provided.") + if not self.is_multi_speaker and speaker is not None and "voice_dir" not in kwargs: + raise ValueError("Model is not multi-speaker but `speaker` is provided.") + if not self.is_multi_lingual and language is not None: + raise ValueError("Model is not multi-lingual but `language` is provided.") + if not emotion is None and not speed is None: + raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.") + + def tts( + self, + text: str, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = None, + split_sentences: bool = True, + **kwargs, + ): + """Convert text to speech. + + Args: + text (str): + Input text to synthesize. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + language (str): Language of the text. If None, the default language of the speaker is used. Language is only + supported by `XTTS` model. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + emotion (str, optional): + Emotion to use for 🐸Coqui Studio models. If None, Studio models use "Neutral". Defaults to None. + speed (float, optional): + Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0. + Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + kwargs (dict, optional): + Additional arguments for the model. + """ + self._check_arguments( + speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs + ) + wav = self.synthesizer.tts( + text=text, + speaker_name=speaker, + language_name=language, + speaker_wav=speaker_wav, + reference_wav=None, + style_wav=None, + style_text=None, + reference_speaker_name=None, + split_sentences=split_sentences, + **kwargs, + ) + return wav + + def tts_to_file( + self, + text: str, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = 1.0, + pipe_out=None, + file_path: str = "output.wav", + split_sentences: bool = True, + **kwargs, + ): + """Convert text to speech. + + Args: + text (str): + Input text to synthesize. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + emotion (str, optional): + Emotion to use for 🐸Coqui Studio models. Defaults to "Neutral". + speed (float, optional): + Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None. + pipe_out (BytesIO, optional): + Flag to stdout the generated TTS wav file for shell pipe. + file_path (str, optional): + Output file path. Defaults to "output.wav". + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + kwargs (dict, optional): + Additional arguments for the model. + """ + self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) + + wav = self.tts( + text=text, + speaker=speaker, + language=language, + speaker_wav=speaker_wav, + split_sentences=split_sentences, + **kwargs, + ) + self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out) + return file_path + + def voice_conversion( + self, + source_wav: str, + target_wav: str, + ): + """Voice conversion with FreeVC. Convert source wav to target speaker. + + Args:`` + source_wav (str): + Path to the source wav file. + target_wav (str):` + Path to the target wav file. + """ + wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav) + return wav + + def voice_conversion_to_file( + self, + source_wav: str, + target_wav: str, + file_path: str = "output.wav", + ): + """Voice conversion with FreeVC. Convert source wav to target speaker. + + Args: + source_wav (str): + Path to the source wav file. + target_wav (str): + Path to the target wav file. + file_path (str, optional): + Output file path. Defaults to "output.wav". + """ + wav = self.voice_conversion(source_wav=source_wav, target_wav=target_wav) + save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) + return file_path + + def tts_with_vc( + self, + text: str, + language: str = None, + speaker_wav: str = None, + speaker: str = None, + split_sentences: bool = True, + ): + """Convert text to speech with voice conversion. + + It combines tts with voice conversion to fake voice cloning. + + - Convert text to speech with tts. + - Convert the output wav to target speaker with voice conversion. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + """ + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + # Lazy code... save it to a temp file to resample it while reading it for VC + self.tts_to_file( + text=text, speaker=speaker, language=language, file_path=fp.name, split_sentences=split_sentences + ) + if self.voice_converter is None: + self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24") + wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav) + return wav + + def tts_with_vc_to_file( + self, + text: str, + language: str = None, + speaker_wav: str = None, + file_path: str = "output.wav", + speaker: str = None, + split_sentences: bool = True, + ): + """Convert text to speech with voice conversion and save to file. + + Check `tts_with_vc` for more details. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + file_path (str, optional): + Output file path. Defaults to "output.wav". + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + """ + wav = self.tts_with_vc( + text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences + ) + save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) diff --git a/TTS/bin/__init__.py b/TTS/bin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TTS/bin/collect_env_info.py b/TTS/bin/collect_env_info.py new file mode 100644 index 0000000..662fcd0 --- /dev/null +++ b/TTS/bin/collect_env_info.py @@ -0,0 +1,48 @@ +"""Get detailed info about the working environment.""" +import os +import platform +import sys + +import numpy +import torch + +sys.path += [os.path.abspath(".."), os.path.abspath(".")] +import json + +import TTS + + +def system_info(): + return { + "OS": platform.system(), + "architecture": platform.architecture(), + "version": platform.version(), + "processor": platform.processor(), + "python": platform.python_version(), + } + + +def cuda_info(): + return { + "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], + "available": torch.cuda.is_available(), + "version": torch.version.cuda, + } + + +def package_info(): + return { + "numpy": numpy.__version__, + "PyTorch_version": torch.__version__, + "PyTorch_debug": torch.version.debug, + "TTS": TTS.__version__, + } + + +def main(): + details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()} + print(json.dumps(details, indent=4, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py new file mode 100644 index 0000000..9ab520b --- /dev/null +++ b/TTS/bin/compute_attention_masks.py @@ -0,0 +1,165 @@ +import argparse +import importlib +import os +from argparse import RawTextHelpFormatter + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets.TTSDataset import TTSDataset +from TTS.tts.models import setup_model +from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_checkpoint + +if __name__ == "__main__": + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Extract attention masks from trained Tacotron/Tacotron2 models. +These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n""" + """Each attention mask is written to the same path as the input wav file with ".npy" file extension. +(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n""" + """ +Example run: + CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py + --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth + --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json + --dataset_metafile metadata.csv + --data_path /root/LJSpeech-1.1/ + --batch_size 32 + --dataset ljspeech + --use_cuda True +""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ") + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Path to Tacotron/Tacotron2 config file.", + ) + parser.add_argument( + "--dataset", + type=str, + default="", + required=True, + help="Target dataset processor name from TTS.tts.dataset.preprocess.", + ) + + parser.add_argument( + "--dataset_metafile", + type=str, + default="", + required=True, + help="Dataset metafile inclusing file paths with transcripts.", + ) + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.") + + parser.add_argument( + "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." + ) + args = parser.parse_args() + + C = load_config(args.config_path) + ap = AudioProcessor(**C.audio) + + # if the vocabulary was passed, replace the default + if "characters" in C.keys(): + symbols, phonemes = make_symbols(**C.characters) + + # load the model + num_chars = len(phonemes) if C.use_phonemes else len(symbols) + # TODO: handle multi-speaker + model = setup_model(C) + model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) + + # data loader + preprocessor = importlib.import_module("TTS.tts.datasets.formatters") + preprocessor = getattr(preprocessor, args.dataset) + meta_data = preprocessor(args.data_path, args.dataset_metafile) + dataset = TTSDataset( + model.decoder.r, + C.text_cleaner, + compute_linear_spec=False, + ap=ap, + meta_data=meta_data, + characters=C.characters if "characters" in C.keys() else None, + add_blank=C["add_blank"] if "add_blank" in C.keys() else False, + use_phonemes=C.use_phonemes, + phoneme_cache_path=C.phoneme_cache_path, + phoneme_language=C.phoneme_language, + enable_eos_bos=C.enable_eos_bos_chars, + ) + + dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False)) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=4, + collate_fn=dataset.collate_fn, + shuffle=False, + drop_last=False, + ) + + # compute attentions + file_paths = [] + with torch.no_grad(): + for data in tqdm(loader): + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[3] + mel_input = data[4] + mel_lengths = data[5] + stop_targets = data[6] + item_idxs = data[7] + + # dispatch data to GPU + if args.use_cuda: + text_input = text_input.cuda() + text_lengths = text_lengths.cuda() + mel_input = mel_input.cuda() + mel_lengths = mel_lengths.cuda() + + model_outputs = model.forward(text_input, text_lengths, mel_input) + + alignments = model_outputs["alignments"].detach() + for idx, alignment in enumerate(alignments): + item_idx = item_idxs[idx] + # interpolate if r > 1 + alignment = ( + torch.nn.functional.interpolate( + alignment.transpose(0, 1).unsqueeze(0), + size=None, + scale_factor=model.decoder.r, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + ) + .squeeze(0) + .transpose(0, 1) + ) + # remove paddings + alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() + # set file paths + wav_file_name = os.path.basename(item_idx) + align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy" + file_path = item_idx.replace(wav_file_name, align_file_name) + # save output + wav_file_abs_path = os.path.abspath(item_idx) + file_abs_path = os.path.abspath(file_path) + file_paths.append([wav_file_abs_path, file_abs_path]) + np.save(file_path, alignment) + + # ourput metafile + metafile = os.path.join(args.data_path, "metadata_attn_mask.txt") + + with open(metafile, "w", encoding="utf-8") as f: + for p in file_paths: + f.write(f"{p[0]}|{p[1]}\n") + print(f" >> Metafile created: {metafile}") diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py new file mode 100644 index 0000000..5b5a37d --- /dev/null +++ b/TTS/bin/compute_embeddings.py @@ -0,0 +1,197 @@ +import argparse +import os +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.managers import save_file +from TTS.tts.utils.speakers import SpeakerManager + + +def compute_embeddings( + model_path, + config_path, + output_path, + old_speakers_file=None, + old_append=False, + config_dataset_path=None, + formatter_name=None, + dataset_name=None, + dataset_path=None, + meta_file_train=None, + meta_file_val=None, + disable_cuda=False, + no_eval=False, +): + use_cuda = torch.cuda.is_available() and not disable_cuda + + if config_dataset_path is not None: + c_dataset = load_config(config_dataset_path) + meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not no_eval) + else: + c_dataset = BaseDatasetConfig() + c_dataset.formatter = formatter_name + c_dataset.dataset_name = dataset_name + c_dataset.path = dataset_path + if meta_file_train is not None: + c_dataset.meta_file_train = meta_file_train + if meta_file_val is not None: + c_dataset.meta_file_val = meta_file_val + meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not no_eval) + + if meta_data_eval is None: + samples = meta_data_train + else: + samples = meta_data_train + meta_data_eval + + encoder_manager = SpeakerManager( + encoder_model_path=model_path, + encoder_config_path=config_path, + d_vectors_file_path=old_speakers_file, + use_cuda=use_cuda, + ) + + class_name_key = encoder_manager.encoder_config.class_name_key + + # compute speaker embeddings + if old_speakers_file is not None and old_append: + speaker_mapping = encoder_manager.embeddings + else: + speaker_mapping = {} + + for fields in tqdm(samples): + class_name = fields[class_name_key] + audio_file = fields["audio_file"] + embedding_key = fields["audio_unique_name"] + + # Only update the speaker name when the embedding is already in the old file. + if embedding_key in speaker_mapping: + speaker_mapping[embedding_key]["name"] = class_name + continue + + if old_speakers_file is not None and embedding_key in encoder_manager.clip_ids: + # get the embedding from the old file + embedd = encoder_manager.get_embedding_by_clip(embedding_key) + else: + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(audio_file) + + # create speaker_mapping if target dataset is defined + speaker_mapping[embedding_key] = {} + speaker_mapping[embedding_key]["name"] = class_name + speaker_mapping[embedding_key]["embedding"] = embedd + + if speaker_mapping: + # save speaker_mapping if target dataset is defined + if os.path.isdir(output_path): + mapping_file_path = os.path.join(output_path, "speakers.pth") + else: + mapping_file_path = output_path + + if os.path.dirname(mapping_file_path) != "": + os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True) + + save_file(speaker_mapping, mapping_file_path) + print("Speaker embeddings saved at:", mapping_file_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n""" + """ + Example runs: + python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --config_dataset_path dataset_config.json + + python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --formatter_name coqui --dataset_path /path/to/vctk/dataset --dataset_name my_vctk --meta_file_train /path/to/vctk/metafile_train.csv --meta_file_val /path/to/vctk/metafile_eval.csv + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument( + "--model_path", + type=str, + help="Path to model checkpoint file. It defaults to the released speaker encoder.", + default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar", + ) + parser.add_argument( + "--config_path", + type=str, + help="Path to model config file. It defaults to the released speaker encoder config.", + default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json", + ) + parser.add_argument( + "--config_dataset_path", + type=str, + help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.", + default=None, + ) + parser.add_argument( + "--output_path", + type=str, + help="Path for output `pth` or `json` file.", + default="speakers.pth", + ) + parser.add_argument( + "--old_file", + type=str, + help="The old existing embedding file, from which the embeddings will be directly loaded for already computed audio clips.", + default=None, + ) + parser.add_argument( + "--old_append", + help="Append new audio clip embeddings to the old embedding file, generate a new non-duplicated merged embedding file. Default False", + default=False, + action="store_true", + ) + parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False) + parser.add_argument("--no_eval", help="Do not compute eval?. Default False", default=False, action="store_true") + parser.add_argument( + "--formatter_name", + type=str, + help="Name of the formatter to use. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--dataset_name", + type=str, + help="Name of the dataset to use. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--dataset_path", + type=str, + help="Path to the dataset. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--meta_file_train", + type=str, + help="Path to the train meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--meta_file_val", + type=str, + help="Path to the evaluation meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`", + default=None, + ) + args = parser.parse_args() + + compute_embeddings( + args.model_path, + args.config_path, + args.output_path, + old_speakers_file=args.old_file, + old_append=args.old_append, + config_dataset_path=args.config_dataset_path, + formatter_name=args.formatter_name, + dataset_name=args.dataset_name, + dataset_path=args.dataset_path, + meta_file_train=args.meta_file_train, + meta_file_val=args.meta_file_val, + disable_cuda=args.disable_cuda, + no_eval=args.no_eval, + ) diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py new file mode 100644 index 0000000..3ab7ea7 --- /dev/null +++ b/TTS/bin/compute_statistics.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import glob +import os + +import numpy as np +from tqdm import tqdm + +# from TTS.utils.io import load_config +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor + + +def main(): + """Run preprocessing process.""" + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") + parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") + parser.add_argument("out_path", type=str, help="save path (directory and filename).") + parser.add_argument( + "--data_path", + type=str, + required=False, + help="folder including the target set of wavs overriding dataset config.", + ) + args, overrides = parser.parse_known_args() + + CONFIG = load_config(args.config_path) + CONFIG.parse_known_args(overrides, relaxed_parser=True) + + # load config + CONFIG.audio.signal_norm = False # do not apply earlier normalization + CONFIG.audio.stats_path = None # discard pre-defined stats + + # load audio processor + ap = AudioProcessor(**CONFIG.audio.to_dict()) + + # load the meta data of target dataset + if args.data_path: + dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True) + else: + dataset_items = load_tts_samples(CONFIG.datasets)[0] # take only train data + print(f" > There are {len(dataset_items)} files.") + + mel_sum = 0 + mel_square_sum = 0 + linear_sum = 0 + linear_square_sum = 0 + N = 0 + for item in tqdm(dataset_items): + # compute features + wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"]) + linear = ap.spectrogram(wav) + mel = ap.melspectrogram(wav) + + # compute stats + N += mel.shape[1] + mel_sum += mel.sum(1) + linear_sum += linear.sum(1) + mel_square_sum += (mel**2).sum(axis=1) + linear_square_sum += (linear**2).sum(axis=1) + + mel_mean = mel_sum / N + mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2) + linear_mean = linear_sum / N + linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2) + + output_file_path = args.out_path + stats = {} + stats["mel_mean"] = mel_mean + stats["mel_std"] = mel_scale + stats["linear_mean"] = linear_mean + stats["linear_std"] = linear_scale + + print(f" > Avg mel spec mean: {mel_mean.mean()}") + print(f" > Avg mel spec scale: {mel_scale.mean()}") + print(f" > Avg linear spec mean: {linear_mean.mean()}") + print(f" > Avg linear spec scale: {linear_scale.mean()}") + + # set default config values for mean-var scaling + CONFIG.audio.stats_path = output_file_path + CONFIG.audio.signal_norm = True + # remove redundant values + del CONFIG.audio.max_norm + del CONFIG.audio.min_level_db + del CONFIG.audio.symmetric_norm + del CONFIG.audio.clip_norm + stats["audio_config"] = CONFIG.audio.to_dict() + np.save(output_file_path, stats, allow_pickle=True) + print(f" > stats saved to {output_file_path}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py new file mode 100644 index 0000000..60fed13 --- /dev/null +++ b/TTS/bin/eval_encoder.py @@ -0,0 +1,88 @@ +import argparse +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.speakers import SpeakerManager + + +def compute_encoder_accuracy(dataset_items, encoder_manager): + class_name_key = encoder_manager.encoder_config.class_name_key + map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None) + + class_acc_dict = {} + + # compute embeddings for all wav_files + for item in tqdm(dataset_items): + class_name = item[class_name_key] + wav_file = item["audio_file"] + + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(wav_file) + if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None: + embedding = torch.FloatTensor(embedd).unsqueeze(0) + if encoder_manager.use_cuda: + embedding = embedding.cuda() + + class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item() + predicted_label = map_classid_to_classname[str(class_id)] + else: + predicted_label = None + + if class_name is not None and predicted_label is not None: + is_equal = int(class_name == predicted_label) + if class_name not in class_acc_dict: + class_acc_dict[class_name] = [is_equal] + else: + class_acc_dict[class_name].append(is_equal) + else: + raise RuntimeError("Error: class_name or/and predicted_label are None") + + acc_avg = 0 + for key, values in class_acc_dict.items(): + acc = sum(values) / len(values) + print("Class", key, "Accuracy:", acc) + acc_avg += acc + + print("Average Accuracy:", acc_avg / len(class_acc_dict)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Compute the accuracy of the encoder.\n\n""" + """ + Example runs: + python TTS/bin/eval_encoder.py emotion_encoder_model.pth emotion_encoder_config.json dataset_config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") + parser.add_argument( + "config_path", + type=str, + help="Path to model config file.", + ) + + parser.add_argument( + "config_dataset_path", + type=str, + help="Path to dataset config file.", + ) + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + + args = parser.parse_args() + + c_dataset = load_config(args.config_dataset_path) + + meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval) + items = meta_data_train + meta_data_eval + + enc_manager = SpeakerManager( + encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda + ) + + compute_encoder_accuracy(items, enc_manager) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py new file mode 100644 index 0000000..c604862 --- /dev/null +++ b/TTS/bin/extract_tts_spectrograms.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +"""Extract Mel spectrograms with teacher forcing.""" + +import argparse +import os + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.tts.models import setup_model +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import quantize +from TTS.utils.generic_utils import count_parameters + +use_cuda = torch.cuda.is_available() + + +def setup_loader(ap, r, verbose=False): + tokenizer, _ = TTSTokenizer.init_from_config(c) + dataset = TTSDataset( + outputs_per_step=r, + compute_linear_spec=False, + samples=meta_data, + tokenizer=tokenizer, + ap=ap, + batch_group_size=0, + min_text_len=c.min_text_len, + max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, + phoneme_cache_path=c.phoneme_cache_path, + precompute_num_workers=0, + use_noise_augment=False, + verbose=verbose, + speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None, + d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None, + ) + + if c.use_phonemes and c.compute_input_seq_cache: + # precompute phonemes to have a better estimate of sequence lengths. + dataset.compute_input_seq(c.num_loader_workers) + dataset.preprocess_samples() + + loader = DataLoader( + dataset, + batch_size=c.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=None, + num_workers=c.num_loader_workers, + pin_memory=False, + ) + return loader + + +def set_filename(wav_path, out_path): + wav_file = os.path.basename(wav_path) + file_name = wav_file.split(".")[0] + os.makedirs(os.path.join(out_path, "quant"), exist_ok=True) + os.makedirs(os.path.join(out_path, "mel"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav"), exist_ok=True) + wavq_path = os.path.join(out_path, "quant", file_name) + mel_path = os.path.join(out_path, "mel", file_name) + wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav") + wav_path = os.path.join(out_path, "wav", file_name + ".wav") + return file_name, wavq_path, mel_path, wav_gl_path, wav_path + + +def format_data(data): + # setup input data + text_input = data["token_id"] + text_lengths = data["token_id_lengths"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + item_idx = data["item_idxs"] + d_vectors = data["d_vectors"] + speaker_ids = data["speaker_ids"] + attn_mask = data["attns"] + avg_text_length = torch.mean(text_lengths.float()) + avg_spec_length = torch.mean(mel_lengths.float()) + + # dispatch data to GPU + if use_cuda: + text_input = text_input.cuda(non_blocking=True) + text_lengths = text_lengths.cuda(non_blocking=True) + mel_input = mel_input.cuda(non_blocking=True) + mel_lengths = mel_lengths.cuda(non_blocking=True) + if speaker_ids is not None: + speaker_ids = speaker_ids.cuda(non_blocking=True) + if d_vectors is not None: + d_vectors = d_vectors.cuda(non_blocking=True) + if attn_mask is not None: + attn_mask = attn_mask.cuda(non_blocking=True) + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + avg_text_length, + avg_spec_length, + attn_mask, + item_idx, + ) + + +@torch.no_grad() +def inference( + model_name, + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=None, + d_vectors=None, +): + if model_name == "glow_tts": + speaker_c = None + if speaker_ids is not None: + speaker_c = speaker_ids + elif d_vectors is not None: + speaker_c = d_vectors + outputs = model.inference_with_MAS( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}, + ) + model_output = outputs["model_outputs"] + model_output = model_output.detach().cpu().numpy() + + elif "tacotron" in model_name: + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input) + postnet_outputs = outputs["model_outputs"] + # normalize tacotron output + if model_name == "tacotron": + mel_specs = [] + postnet_outputs = postnet_outputs.data.cpu().numpy() + for b in range(postnet_outputs.shape[0]): + postnet_output = postnet_outputs[b] + mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T)) + model_output = torch.stack(mel_specs).cpu().numpy() + + elif model_name == "tacotron2": + model_output = postnet_outputs.detach().cpu().numpy() + return model_output + + +def extract_spectrograms( + data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt" +): + model.eval() + export_metadata = [] + for _, data in tqdm(enumerate(data_loader), total=len(data_loader)): + # format data + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + _, + _, + _, + item_idx, + ) = format_data(data) + + model_output = inference( + c.model.lower(), + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + ) + + for idx in range(text_input.shape[0]): + wav_file_path = item_idx[idx] + wav = ap.load_wav(wav_file_path) + _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path) + + # quantize and save wav + if quantize_bits > 0: + wavq = quantize(wav, quantize_bits) + np.save(wavq_path, wavq) + + # save TTS mel + mel = model_output[idx] + mel_length = mel_lengths[idx] + mel = mel[:mel_length, :].T + np.save(mel_path, mel) + + export_metadata.append([wav_file_path, mel_path]) + if save_audio: + ap.save_wav(wav, wav_path) + + if debug: + print("Audio for debug saved at:", wav_gl_path) + wav = ap.inv_melspectrogram(mel) + ap.save_wav(wav, wav_gl_path) + + with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f: + for data in export_metadata: + f.write(f"{data[0]}|{data[1]+'.npy'}\n") + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data, speaker_manager + + # Audio processor + ap = AudioProcessor(**c.audio) + + # load data instances + meta_data_train, meta_data_eval = load_tts_samples( + c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + # use eval and training partitions + meta_data = meta_data_train + meta_data_eval + + # init speaker manager + if c.use_speaker_embedding: + speaker_manager = SpeakerManager(data_items=meta_data) + elif c.use_d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file) + else: + speaker_manager = None + + # setup model + model = setup_model(c) + + # restore model + model.load_checkpoint(c, args.checkpoint_path, eval=True) + + if use_cuda: + model.cuda() + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + # set r + r = 1 if c.model.lower() == "glow_tts" else model.decoder.r + own_loader = setup_loader(ap, r, verbose=True) + + extract_spectrograms( + own_loader, + model, + ap, + args.output_path, + quantize_bits=args.quantize_bits, + save_audio=args.save_audio, + debug=args.debug, + metada_name="metada.txt", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) + parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True) + parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True) + parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") + parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") + parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero") + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + args = parser.parse_args() + + c = load_config(args.config_path) + c.audio.trim_silence = False + main(args) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py new file mode 100644 index 0000000..ea16974 --- /dev/null +++ b/TTS/bin/find_unique_chars.py @@ -0,0 +1,45 @@ +"""Find all the unique characters in a dataset""" +import argparse +from argparse import RawTextHelpFormatter + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples + + +def main(): + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_chars.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + items = train_items + eval_items + + texts = "".join(item["text"] for item in items) + chars = set(texts) + lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + + print(f" > Number of unique characters: {len(chars)}") + print(f" > Unique characters: {''.join(sorted(chars))}") + print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") + print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py new file mode 100644 index 0000000..4bd7a78 --- /dev/null +++ b/TTS/bin/find_unique_phonemes.py @@ -0,0 +1,74 @@ +"""Find all the unique characters in a dataset""" +import argparse +import multiprocessing +from argparse import RawTextHelpFormatter + +from tqdm.contrib.concurrent import process_map + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.text.phonemizers import Gruut + + +def compute_phonemes(item): + text = item["text"] + ph = phonemizer.phonemize(text).replace("|", "") + return set(list(ph)) + + +def main(): + # pylint: disable=W0601 + global c, phonemizer + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_phonemes.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + items = train_items + eval_items + print("Num items:", len(items)) + + language_list = [item["language"] for item in items] + is_lang_def = all(language_list) + + if not c.phoneme_language or not is_lang_def: + raise ValueError("Phoneme language must be defined in config.") + + if not language_list.count(language_list[0]) == len(language_list): + raise ValueError( + "Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!" + ) + + phonemizer = Gruut(language=language_list[0], keep_puncs=True) + + phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) + phones = [] + for ph in phonemes: + phones.extend(ph) + + phones = set(phones) + lower_phones = filter(lambda c: c.islower(), phones) + phones_force_lower = [c.lower() for c in phones] + phones_force_lower = set(phones_force_lower) + + print(f" > Number of unique phonemes: {len(phones)}") + print(f" > Unique phonemes: {''.join(sorted(phones))}") + print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}") + print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py new file mode 100644 index 0000000..a1eaf4c --- /dev/null +++ b/TTS/bin/remove_silence_using_vad.py @@ -0,0 +1,124 @@ +import argparse +import glob +import multiprocessing +import os +import pathlib + +import torch +from tqdm import tqdm + +from TTS.utils.vad import get_vad_model_and_utils, remove_silence + +torch.set_num_threads(1) + + +def adjust_path_and_remove_silence(audio_path): + output_path = audio_path.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, "")) + # ignore if the file exists + if os.path.exists(output_path) and not args.force: + return output_path, False + + # create all directory structure + pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) + # remove the silence and save the audio + output_path, is_speech = remove_silence( + model_and_utils, + audio_path, + output_path, + trim_just_beginning_and_end=args.trim_just_beginning_and_end, + use_cuda=args.use_cuda, + ) + return output_path, is_speech + + +def preprocess_audios(): + files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True)) + print("> Number of files: ", len(files)) + if not args.force: + print("> Ignoring files that already exist in the output idrectory.") + + if args.trim_just_beginning_and_end: + print("> Trimming just the beginning and the end with nonspeech parts.") + else: + print("> Trimming all nonspeech parts.") + + filtered_files = [] + if files: + # create threads + # num_threads = multiprocessing.cpu_count() + # process_map(adjust_path_and_remove_silence, files, max_workers=num_threads, chunksize=15) + + if args.num_processes > 1: + with multiprocessing.Pool(processes=args.num_processes) as pool: + results = list( + tqdm( + pool.imap_unordered(adjust_path_and_remove_silence, files), + total=len(files), + desc="Processing audio files", + ) + ) + for output_path, is_speech in results: + if not is_speech: + filtered_files.append(output_path) + else: + for f in tqdm(files): + output_path, is_speech = adjust_path_and_remove_silence(f) + if not is_speech: + filtered_files.append(output_path) + + # write files that do not have speech + with open(os.path.join(args.output_dir, "filtered_files.txt"), "w", encoding="utf-8") as f: + for file in filtered_files: + f.write(str(file) + "\n") + else: + print("> No files Found !") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True" + ) + parser.add_argument("-i", "--input_dir", type=str, help="Dataset root dir", required=True) + parser.add_argument("-o", "--output_dir", type=str, help="Output Dataset dir", default="") + parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files") + parser.add_argument( + "-g", + "--glob", + type=str, + default="**/*.wav", + help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav", + ) + parser.add_argument( + "-t", + "--trim_just_beginning_and_end", + type=bool, + default=True, + help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. Default True", + ) + parser.add_argument( + "-c", + "--use_cuda", + type=bool, + default=False, + help="If True use cuda", + ) + parser.add_argument( + "--use_onnx", + type=bool, + default=False, + help="If True use onnx", + ) + parser.add_argument( + "--num_processes", + type=int, + default=1, + help="Number of processes to use", + ) + args = parser.parse_args() + + if args.output_dir == "": + args.output_dir = args.input_dir + + # load the model and utils + model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda, use_onnx=args.use_onnx) + preprocess_audios() diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py new file mode 100644 index 0000000..a3f2848 --- /dev/null +++ b/TTS/bin/resample.py @@ -0,0 +1,90 @@ +import argparse +import glob +import os +from argparse import RawTextHelpFormatter +from multiprocessing import Pool +from shutil import copytree + +import librosa +import soundfile as sf +from tqdm import tqdm + + +def resample_file(func_args): + filename, output_sr = func_args + y, sr = librosa.load(filename, sr=output_sr) + sf.write(filename, y, sr) + + +def resample_files(input_dir, output_sr, output_dir=None, file_ext="wav", n_jobs=10): + if output_dir: + print("Recursively copying the input folder...") + copytree(input_dir, output_dir) + input_dir = output_dir + + print("Resampling the audio files...") + audio_files = glob.glob(os.path.join(input_dir, f"**/*.{file_ext}"), recursive=True) + print(f"Found {len(audio_files)} files...") + audio_files = list(zip(audio_files, len(audio_files) * [output_sr])) + with Pool(processes=n_jobs) as p: + with tqdm(total=len(audio_files)) as pbar: + for _, _ in enumerate(p.imap_unordered(resample_file, audio_files)): + pbar.update() + + print("Done !") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Resample a folder recusively with librosa + Can be used in place or create a copy of the folder as an output.\n\n + Example run: + python TTS/bin/resample.py + --input_dir /root/LJSpeech-1.1/ + --output_sr 22050 + --output_dir /root/resampled_LJSpeech-1.1/ + --file_ext wav + --n_jobs 24 + """, + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--input_dir", + type=str, + default=None, + required=True, + help="Path of the folder containing the audio files to resample", + ) + + parser.add_argument( + "--output_sr", + type=int, + default=22050, + required=False, + help="Samlple rate to which the audio files should be resampled", + ) + + parser.add_argument( + "--output_dir", + type=str, + default=None, + required=False, + help="Path of the destination folder. If not defined, the operation is done in place", + ) + + parser.add_argument( + "--file_ext", + type=str, + default="wav", + required=False, + help="Extension of the audio files to resample", + ) + + parser.add_argument( + "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" + ) + + args = parser.parse_args() + + resample_files(args.input_dir, args.output_sr, args.output_dir, args.file_ext, args.n_jobs) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py new file mode 100644 index 0000000..b86252a --- /dev/null +++ b/TTS/bin/synthesize.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import contextlib +import sys +from argparse import RawTextHelpFormatter + +# pylint: disable=redefined-outer-name, unused-argument +from pathlib import Path + +description = """ +Synthesize speech on command line. + +You can either use your trained model or choose a model from the provided list. + +If you don't specify any models, then it uses LJSpeech based English model. + +#### Single Speaker Models + +- List provided models: + + ``` + $ tts --list_models + ``` + +- Get model info (for both tts_models and vocoder_models): + + - Query by type/name: + The model_info_by_name uses the name as it from the --list_models. + ``` + $ tts --model_info_by_name "///" + ``` + For example: + ``` + $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts + $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2 + ``` + - Query by type/idx: + The model_query_idx uses the corresponding idx from --list_models. + + ``` + $ tts --model_info_by_idx "/" + ``` + + For example: + + ``` + $ tts --model_info_by_idx tts_models/3 + ``` + + - Query info for model info by full name: + ``` + $ tts --model_info_by_name "///" + ``` + +- Run TTS with default models: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav + ``` + +- Run TTS and pipe out the generated TTS wav file data: + + ``` + $ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay + ``` + +- Run a TTS model with its default vocoder model: + + ``` + $ tts --text "Text for TTS" --model_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav + ``` + +- Run with specific TTS and vocoder models from the list: + + ``` + $ tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav + ``` + +- Run your own TTS model (Using Griffin-Lim Vocoder): + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + ``` + +- Run your own TTS and Vocoder models: + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json + ``` + +#### Multi-speaker Models + +- List the available speakers and choose a among them: + + ``` + $ tts --model_name "//" --list_speaker_idxs + ``` + +- Run the multi-speaker TTS model with the target speaker ID: + + ``` + $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx + ``` + +- Run your own multi-speaker TTS model: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx + ``` + +### Voice Conversion Models + +``` +$ tts --out_path output/path/speech.wav --model_name "//" --source_wav --target_wav +``` +""" + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + if v.lower() in ("no", "false", "f", "n", "0"): + return False + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def main(): + parser = argparse.ArgumentParser( + description=description.replace(" ```\n", ""), + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--list_models", + type=str2bool, + nargs="?", + const=True, + default=False, + help="list available pre-trained TTS and vocoder models.", + ) + + parser.add_argument( + "--model_info_by_idx", + type=str, + default=None, + help="model info using query format: /", + ) + + parser.add_argument( + "--model_info_by_name", + type=str, + default=None, + help="model info using query format: ///", + ) + + parser.add_argument("--text", type=str, default=None, help="Text to generate speech.") + + # Args for running pre-trained TTS models. + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/tacotron2-DDC", + help="Name of one of the pre-trained TTS models in format //", + ) + parser.add_argument( + "--vocoder_name", + type=str, + default=None, + help="Name of one of the pre-trained vocoder models in format //", + ) + + # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model file.", + ) + parser.add_argument( + "--out_path", + type=str, + default="tts_output.wav", + help="Output wav file path.", + ) + parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) + parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu") + parser.add_argument( + "--vocoder_path", + type=str, + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).", + default=None, + ) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) + parser.add_argument( + "--encoder_path", + type=str, + help="Path to speaker encoder model file.", + default=None, + ) + parser.add_argument("--encoder_config_path", type=str, help="Path to speaker encoder config file.", default=None) + parser.add_argument( + "--pipe_out", + help="stdout the generated TTS wav file for shell pipe.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + + # args for multi-speaker synthesis + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) + parser.add_argument( + "--speaker_idx", + type=str, + help="Target speaker ID for a multi-speaker TTS model.", + default=None, + ) + parser.add_argument( + "--language_idx", + type=str, + help="Target language ID for a multi-lingual TTS model.", + default=None, + ) + parser.add_argument( + "--speaker_wav", + nargs="+", + help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.", + default=None, + ) + parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None) + parser.add_argument( + "--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None + ) + parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None) + parser.add_argument( + "--list_speaker_idxs", + help="List available speaker ids for the defined multi-speaker model.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + parser.add_argument( + "--list_language_idxs", + help="List available language ids for the defined multi-lingual model.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + # aux args + parser.add_argument( + "--save_spectogram", + type=bool, + help="If true save raw spectogram for further (vocoder) processing in out_path.", + default=False, + ) + parser.add_argument( + "--reference_wav", + type=str, + help="Reference wav file to convert in the voice of the speaker_idx or speaker_wav", + default=None, + ) + parser.add_argument( + "--reference_speaker_idx", + type=str, + help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).", + default=None, + ) + parser.add_argument( + "--progress_bar", + type=str2bool, + help="If true shows a progress bar for the model download. Defaults to True", + default=True, + ) + + # voice conversion args + parser.add_argument( + "--source_wav", + type=str, + default=None, + help="Original audio file to convert in the voice of the target_wav", + ) + parser.add_argument( + "--target_wav", + type=str, + default=None, + help="Target audio file to convert in the voice of the source_wav", + ) + + parser.add_argument( + "--voice_dir", + type=str, + default=None, + help="Voice dir for tortoise model", + ) + + args = parser.parse_args() + + # print the description if either text or list_models is not set + check_args = [ + args.text, + args.list_models, + args.list_speaker_idxs, + args.list_language_idxs, + args.reference_wav, + args.model_info_by_idx, + args.model_info_by_name, + args.source_wav, + args.target_wav, + ] + if not any(check_args): + parser.parse_args(["-h"]) + + pipe_out = sys.stdout if args.pipe_out else None + + with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout): + # Late-import to make things load faster + from TTS.api import TTS + from TTS.utils.manage import ModelManager + from TTS.utils.synthesizer import Synthesizer + + # load model manager + path = Path(__file__).parent / "../.models.json" + manager = ModelManager(path, progress_bar=args.progress_bar) + api = TTS() + + tts_path = None + tts_config_path = None + speakers_file_path = None + language_ids_file_path = None + vocoder_path = None + vocoder_config_path = None + encoder_path = None + encoder_config_path = None + vc_path = None + vc_config_path = None + model_dir = None + + # CASE1 #list : list pre-trained TTS models + if args.list_models: + manager.list_models() + sys.exit() + + # CASE2 #info : model info for pre-trained TTS models + if args.model_info_by_idx: + model_query = args.model_info_by_idx + manager.model_info_by_idx(model_query) + sys.exit() + + if args.model_info_by_name: + model_query_full_name = args.model_info_by_name + manager.model_info_by_full_name(model_query_full_name) + sys.exit() + + # CASE3: load pre-trained model paths + if args.model_name is not None and not args.model_path: + model_path, config_path, model_item = manager.download_model(args.model_name) + # tts model + if model_item["model_type"] == "tts_models": + tts_path = model_path + tts_config_path = config_path + if "default_vocoder" in model_item: + args.vocoder_name = ( + model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + ) + + # voice conversion model + if model_item["model_type"] == "voice_conversion_models": + vc_path = model_path + vc_config_path = config_path + + # tts model with multiple files to be loaded from the directory path + if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list): + model_dir = model_path + tts_path = None + tts_config_path = None + args.vocoder_name = None + + # load vocoder + if args.vocoder_name is not None and not args.vocoder_path: + vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) + + # CASE4: set custom model paths + if args.model_path is not None: + tts_path = args.model_path + tts_config_path = args.config_path + speakers_file_path = args.speakers_file_path + language_ids_file_path = args.language_ids_file_path + + if args.vocoder_path is not None: + vocoder_path = args.vocoder_path + vocoder_config_path = args.vocoder_config_path + + if args.encoder_path is not None: + encoder_path = args.encoder_path + encoder_config_path = args.encoder_config_path + + device = args.device + if args.use_cuda: + device = "cuda" + + # load models + synthesizer = Synthesizer( + tts_path, + tts_config_path, + speakers_file_path, + language_ids_file_path, + vocoder_path, + vocoder_config_path, + encoder_path, + encoder_config_path, + vc_path, + vc_config_path, + model_dir, + args.voice_dir, + ).to(device) + + # query speaker ids of a multi-speaker model. + if args.list_speaker_idxs: + print( + " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + ) + print(synthesizer.tts_model.speaker_manager.name_to_id) + return + + # query langauge ids of a multi-lingual model. + if args.list_language_idxs: + print( + " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + ) + print(synthesizer.tts_model.language_manager.name_to_id) + return + + # check the arguments against a multi-speaker model. + if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): + print( + " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to " + "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." + ) + return + + # RUN THE SYNTHESIS + if args.text: + print(" > Text: {}".format(args.text)) + + # kick it + if tts_path is not None: + wav = synthesizer.tts( + args.text, + speaker_name=args.speaker_idx, + language_name=args.language_idx, + speaker_wav=args.speaker_wav, + reference_wav=args.reference_wav, + style_wav=args.capacitron_style_wav, + style_text=args.capacitron_style_text, + reference_speaker_name=args.reference_speaker_idx, + ) + elif vc_path is not None: + wav = synthesizer.voice_conversion( + source_wav=args.source_wav, + target_wav=args.target_wav, + ) + elif model_dir is not None: + wav = synthesizer.tts( + args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav + ) + + # save the results + print(" > Saving output to {}".format(args.out_path)) + synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py new file mode 100644 index 0000000..a32ad00 --- /dev/null +++ b/TTS/bin/train_encoder.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import sys +import time +import traceback + +import torch +from torch.utils.data import DataLoader +from trainer.io import copy_model_files, save_best_model, save_checkpoint +from trainer.torch import NoamLR +from trainer.trainer_utils import get_optimizer + +from TTS.encoder.dataset import EncoderDataset +from TTS.encoder.utils.generic_utils import setup_encoder_model +from TTS.encoder.utils.training import init_training +from TTS.encoder.utils.visual import plot_embeddings +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import count_parameters, remove_experiment_folder +from TTS.utils.samplers import PerfectBatchSampler +from TTS.utils.training import check_update + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True +torch.manual_seed(54321) +use_cuda = torch.cuda.is_available() +num_gpus = torch.cuda.device_count() +print(" > Using CUDA: ", use_cuda) +print(" > Number of GPUs: ", num_gpus) + + +def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): + num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class + num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch + + dataset = EncoderDataset( + c, + ap, + meta_data_eval if is_val else meta_data_train, + voice_len=c.voice_len, + num_utter_per_class=num_utter_per_class, + num_classes_in_batch=num_classes_in_batch, + verbose=verbose, + augmentation_config=c.audio_augmentation if not is_val else None, + use_torch_spec=c.model_params.get("use_torch_spec", False), + ) + # get classes list + classes = dataset.get_class_list() + + sampler = PerfectBatchSampler( + dataset.items, + classes, + batch_size=num_classes_in_batch * num_utter_per_class, # total batch size + num_classes_in_batch=num_classes_in_batch, + num_gpus=1, + shuffle=not is_val, + drop_last=True, + ) + + if len(classes) < num_classes_in_batch: + if is_val: + raise RuntimeError( + f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !" + ) + raise RuntimeError( + f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !" + ) + + # set the classes to avoid get wrong class_id when the number of training and eval classes are not equal + if is_val: + dataset.set_classes(train_classes) + + loader = DataLoader( + dataset, + num_workers=c.num_loader_workers, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + ) + + return loader, classes, dataset.get_map_classid_to_classname() + + +def evaluation(model, criterion, data_loader, global_step): + eval_loss = 0 + for _, data in enumerate(data_loader): + with torch.no_grad(): + # setup input data + inputs, labels = data + + # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose( + labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1 + ).reshape(labels.shape) + inputs = torch.transpose( + inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1 + ).reshape(inputs.shape) + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels + ) + + eval_loss += loss.item() + + eval_avg_loss = eval_loss / len(data_loader) + # save stats + dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss}) + # plot the last batch in the evaluation + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.eval_figures(global_step, figures) + return eval_avg_loss + + +def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step): + model.train() + best_loss = {"train_loss": None, "eval_loss": float("inf")} + avg_loader_time = 0 + end_time = time.time() + for epoch in range(c.epochs): + tot_loss = 0 + epoch_time = 0 + for _, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + inputs, labels = data + # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape( + labels.shape + ) + inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape( + inputs.shape + ) + # ToDo: move it to a unit test + # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) + # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) + # idx = 0 + # for j in range(0, c.num_classes_in_batch, 1): + # for i in range(j, len(labels), c.num_classes_in_batch): + # if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])): + # print("Invalid") + # print(labels) + # exit() + # idx += 1 + # labels = labels_converted + # inputs = inputs_converted + + loader_time = time.time() - end_time + global_step += 1 + + # setup lr + if c.lr_decay: + scheduler.step() + optimizer.zero_grad() + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels + ) + loss.backward() + grad_norm, _ = check_update(model, c.grad_clip) + optimizer.step() + + step_time = time.time() - start_time + epoch_time += step_time + + # acumulate the total epoch loss + tot_loss += loss.item() + + # Averaged Loader Time + num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1 + avg_loader_time = ( + 1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time + if avg_loader_time != 0 + else loader_time + ) + current_lr = optimizer.param_groups[0]["lr"] + + if global_step % c.steps_plot_stats == 0: + # Plot Training Epoch Stats + train_stats = { + "loss": loss.item(), + "lr": current_lr, + "grad_norm": grad_norm, + "step_time": step_time, + "avg_loader_time": avg_loader_time, + } + dashboard_logger.train_epoch_stats(global_step, train_stats) + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.train_figures(global_step, figures) + + if global_step % c.print_step == 0: + print( + " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} " + "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( + global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr + ), + flush=True, + ) + + if global_step % c.save_step == 0: + # save model + save_checkpoint( + c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict() + ) + + end_time = time.time() + + print("") + print( + ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} " + "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format( + epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time + ), + flush=True, + ) + # evaluation + if c.run_eval: + model.eval() + eval_loss = evaluation(model, criterion, eval_data_loader, global_step) + print("\n\n") + print("--> EVAL PERFORMANCE") + print( + " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss), + flush=True, + ) + # save the best checkpoint + best_loss = save_best_model( + {"train_loss": None, "eval_loss": eval_loss}, + best_loss, + c, + model, + optimizer, + None, + global_step, + epoch, + OUT_PATH, + criterion=criterion.state_dict(), + ) + model.train() + + return best_loss, global_step + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data_train + global meta_data_eval + global train_classes + + ap = AudioProcessor(**c.audio) + model = setup_encoder_model(c) + + optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model) + + # pylint: disable=redefined-outer-name + meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) + + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) + if c.run_eval: + eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) + else: + eval_data_loader = None + + num_classes = len(train_classes) + criterion = model.get_criterion(c, num_classes) + + if c.loss == "softmaxproto" and c.model != "speaker_encoder": + c.map_classid_to_classname = map_classid_to_classname + copy_model_files(c, OUT_PATH, new_fields={}) + + if args.restore_path: + criterion, args.restore_step = model.load_checkpoint( + c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion + ) + print(" > Model restored from step %d" % args.restore_step, flush=True) + else: + args.restore_step = 0 + + if c.lr_decay: + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) + else: + scheduler = None + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + + if use_cuda: + model = model.cuda() + criterion.cuda() + + global_step = args.restore_step + _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step) + + +if __name__ == "__main__": + args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training() + + try: + main(args) + except KeyboardInterrupt: + remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except + remove_experiment_folder(OUT_PATH) + traceback.print_exc() + sys.exit(1) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py new file mode 100644 index 0000000..bdb4f6f --- /dev/null +++ b/TTS/bin/train_tts.py @@ -0,0 +1,71 @@ +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models import setup_model + + +@dataclass +class TrainTTSArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainTTSArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + train_samples, eval_samples = load_tts_samples( + config.datasets, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the model from config + model = setup_model(config, train_samples + eval_samples) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + model.config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py new file mode 100644 index 0000000..32ecd7b --- /dev/null +++ b/TTS/bin/train_vocoder.py @@ -0,0 +1,77 @@ +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.models import setup_model + + +@dataclass +class TrainVocoderArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainVocoderArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + if "feature_path" in config and config.feature_path: + # load pre-computed features + print(f" > Loading features from: {config.feature_path}") + eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size) + else: + # load data raw wav files + eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # init the model from config + model = setup_model(config) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py new file mode 100644 index 0000000..09582ce --- /dev/null +++ b/TTS/bin/tune_wavegrad.py @@ -0,0 +1,103 @@ +"""Search a good noise schedule for WaveGrad for a given number of inference iterations""" +import argparse +from itertools import product as cartesian_product + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.preprocess import load_wav_data +from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset +from TTS.vocoder.models import setup_model + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") + parser.add_argument("--config_path", type=str, help="Path to model config file.") + parser.add_argument("--data_path", type=str, help="Path to data directory.") + parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") + parser.add_argument( + "--num_iter", + type=int, + help="Number of model inference iterations that you like to optimize noise schedule for.", + ) + parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.") + parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") + parser.add_argument( + "--search_depth", + type=int, + default=3, + help="Search granularity. Increasing this increases the run-time exponentially.", + ) + + # load config + args = parser.parse_args() + config = load_config(args.config_path) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # load dataset + _, train_data = load_wav_data(args.data_path, 0) + train_data = train_data[: args.num_samples] + dataset = WaveGradDataset( + ap=ap, + items=train_data, + seq_len=-1, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=True, + return_segments=False, + use_noise_augment=False, + use_cache=False, + verbose=True, + ) + loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collate_full_clips, + drop_last=False, + num_workers=config.num_loader_workers, + pin_memory=False, + ) + + # setup the model + model = setup_model(config) + if args.use_cuda: + model.cuda() + + # setup optimization parameters + base_values = sorted(10 * np.random.uniform(size=args.search_depth)) + print(f" > base values: {base_values}") + exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) + best_error = float("inf") + best_schedule = None # pylint: disable=C0103 + total_search_iter = len(base_values) ** args.num_iter + for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): + beta = exponents * base + model.compute_noise_level(beta) + for data in loader: + mel, audio = data + y_hat = model.inference(mel.cuda() if args.use_cuda else mel) + + if args.use_cuda: + y_hat = y_hat.cpu() + y_hat = y_hat.numpy() + + mel_hat = [] + for i in range(y_hat.shape[0]): + m = ap.melspectrogram(y_hat[i, 0])[:, :-1] + mel_hat.append(torch.from_numpy(m)) + + mel_hat = torch.stack(mel_hat) + mse = torch.sum((mel - mel_hat) ** 2).mean() + if mse.item() < best_error: + best_error = mse.item() + best_schedule = {"beta": beta} + print(f" > Found a better schedule. - MSE: {mse.item()}") + np.save(args.output_path, best_schedule) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py new file mode 100644 index 0000000..c5a6dd6 --- /dev/null +++ b/TTS/config/__init__.py @@ -0,0 +1,135 @@ +import json +import os +import re +from typing import Dict + +import fsspec +import yaml +from coqpit import Coqpit + +from TTS.config.shared_configs import * +from TTS.utils.generic_utils import find_module + + +def read_json_with_comments(json_path): + """for backward compat.""" + # fallback to json + with fsspec.open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments but not urls with // + input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str) + return json.loads(input_str) + +def register_config(model_name: str) -> Coqpit: + """Find the right config for the given model name. + + Args: + model_name (str): Model name. + + Raises: + ModuleNotFoundError: No matching config for the model name. + + Returns: + Coqpit: config class. + """ + config_class = None + config_name = model_name + "_config" + + # TODO: fix this + if model_name == "xtts": + from TTS.tts.configs.xtts_config import XttsConfig + + config_class = XttsConfig + paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"] + for path in paths: + try: + config_class = find_module(path, config_name) + except ModuleNotFoundError: + pass + if config_class is None: + raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.") + return config_class + + +def _process_model_name(config_dict: Dict) -> str: + """Format the model name as expected. It is a band-aid for the old `vocoder` model names. + + Args: + config_dict (Dict): A dictionary including the config fields. + + Returns: + str: Formatted modelname. + """ + model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"] + model_name = model_name.replace("_generator", "").replace("_discriminator", "") + return model_name + + +def load_config(config_path: str) -> Coqpit: + """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name + to find the corresponding Config class. Then initialize the Config. + + Args: + config_path (str): path to the config file. + + Raises: + TypeError: given config file has an unknown type. + + Returns: + Coqpit: TTS config object. + """ + config_dict = {} + ext = os.path.splitext(config_path)[1] + if ext in (".yml", ".yaml"): + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + elif ext == ".json": + try: + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.decoder.JSONDecodeError: + # backwards compat. + data = read_json_with_comments(config_path) + else: + raise TypeError(f" [!] Unknown config file type {ext}") + config_dict.update(data) + model_name = _process_model_name(config_dict) + config_class = register_config(model_name.lower()) + config = config_class() + config.from_dict(config_dict) + return config + + +def check_config_and_model_args(config, arg_name, value): + """Check the give argument in `config.model_args` if exist or in `config` for + the given value. + + Return False if the argument does not exist in `config.model_args` or `config`. + This is to patch up the compatibility between models with and without `model_args`. + + TODO: Remove this in the future with a unified approach. + """ + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] == value + if hasattr(config, arg_name): + return config[arg_name] == value + return False + + +def get_from_config_or_model_args(config, arg_name): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + return config[arg_name] + + +def get_from_config_or_model_args_with_default(config, arg_name, def_val): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + if hasattr(config, arg_name): + return config[arg_name] + return def_val diff --git a/TTS/config/__pycache__/__init__.cpython-311.pyc b/TTS/config/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..f244eac Binary files /dev/null and b/TTS/config/__pycache__/__init__.cpython-311.pyc differ diff --git a/TTS/config/__pycache__/shared_configs.cpython-311.pyc b/TTS/config/__pycache__/shared_configs.cpython-311.pyc new file mode 100644 index 0000000..4f55e59 Binary files /dev/null and b/TTS/config/__pycache__/shared_configs.cpython-311.pyc differ diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py new file mode 100644 index 0000000..7fae77d --- /dev/null +++ b/TTS/config/shared_configs.py @@ -0,0 +1,268 @@ +from dataclasses import asdict, dataclass +from typing import List + +from coqpit import Coqpit, check_argument +from trainer import TrainerConfig + + +@dataclass +class BaseAudioConfig(Coqpit): + """Base config to definge audio processing parameters. It is used to initialize + ```TTS.utils.audio.AudioProcessor.``` + + Args: + fft_size (int): + Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + + win_length (int): + Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match + ```fft_size```. Defaults to 1024. + + hop_length (int): + Number of audio samples between adjacent STFT columns. Defaults to 1024. + + frame_shift_ms (int): + Set ```hop_length``` based on milliseconds and sampling rate. + + frame_length_ms (int): + Set ```win_length``` based on milliseconds and sampling rate. + + stft_pad_mode (str): + Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + + sample_rate (int): + Audio sampling rate. Defaults to 22050. + + resample (bool): + Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + + preemphasis (float): + Preemphasis coefficient. Defaults to 0.0. + + ref_level_db (int): 20 + Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. + Defaults to 20. + + do_sound_norm (bool): + Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. + + do_trim_silence (bool): + Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + pitch_fmax (float, optional): + Maximum frequency of the F0 frames. Defaults to ```640```. + + pitch_fmin (float, optional): + Minimum frequency of the F0 frames. Defaults to ```1```. + + trim_db (int): + Silence threshold used for silence trimming. Defaults to 45. + + do_rms_norm (bool, optional): + enable/disable RMS volume normalization when loading an audio file. Defaults to False. + + db_level (int, optional): + dB level used for rms normalization. The range is -99 to 0. Defaults to None. + + power (float): + Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the + artifacts in the synthesized voice. Defaults to 1.5. + + griffin_lim_iters (int): + Number of Griffing Lim iterations. Defaults to 60. + + num_mels (int): + Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. + + mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. + It needs to be adjusted for a dataset. Defaults to 0. + + mel_fmax (float): + Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. + + spec_gain (int): + Gain applied when converting amplitude to DB. Defaults to 20. + + signal_norm (bool): + enable/disable signal normalization. Defaults to True. + + min_level_db (int): + minimum db threshold for the computed melspectrograms. Defaults to -100. + + symmetric_norm (bool): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else + [0, k], Defaults to True. + + max_norm (float): + ```k``` defining the normalization range. Defaults to 4.0. + + clip_norm (bool): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + stats_path (str): + Path to the computed stats file. Defaults to None. + """ + + # stft parameters + fft_size: int = 1024 + win_length: int = 1024 + hop_length: int = 256 + frame_shift_ms: int = None + frame_length_ms: int = None + stft_pad_mode: str = "reflect" + # audio processing parameters + sample_rate: int = 22050 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + # silence trimming + do_trim_silence: bool = True + trim_db: int = 45 + # rms volume normalization + do_rms_norm: bool = False + db_level: float = None + # griffin-lim params + power: float = 1.5 + griffin_lim_iters: int = 60 + # mel-spec params + num_mels: int = 80 + mel_fmin: float = 0.0 + mel_fmax: float = None + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + # f0 params + pitch_fmax: float = 640.0 + pitch_fmin: float = 1.0 + # normalization params + signal_norm: bool = True + min_level_db: int = -100 + symmetric_norm: bool = True + max_norm: float = 4.0 + clip_norm: bool = True + stats_path: str = None + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) + check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) + check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c, + restricted=True, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") + check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) + check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) + check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) + check_argument("power", c, restricted=True, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) + + # normalization parameters + check_argument("signal_norm", c, restricted=True) + check_argument("symmetric_norm", c, restricted=True) + check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) + check_argument("clip_norm", c, restricted=True) + check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) + check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) + check_argument("do_trim_silence", c, restricted=True) + check_argument("trim_db", c, restricted=True) + + +@dataclass +class BaseDatasetConfig(Coqpit): + """Base config for TTS datasets. + + Args: + formatter (str): + Formatter name that defines used formatter in ```TTS.tts.datasets.formatter```. Defaults to `""`. + + dataset_name (str): + Unique name for the dataset. Defaults to `""`. + + path (str): + Root path to the dataset files. Defaults to `""`. + + meta_file_train (str): + Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. + Defaults to `""`. + + ignored_speakers (List): + List of speakers IDs that are not used at the training. Default None. + + language (str): + Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`. + + phonemizer (str): + Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`. + + meta_file_val (str): + Name of the dataset meta file that defines the instances used at validation. + + meta_file_attn_mask (str): + Path to the file that lists the attention mask files used with models that require attention masks to + train the duration predictor. + """ + + formatter: str = "" + dataset_name: str = "" + path: str = "" + meta_file_train: str = "" + ignored_speakers: List[str] = None + language: str = "" + phonemizer: str = "" + meta_file_val: str = "" + meta_file_attn_mask: str = "" + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("formatter", c, restricted=True) + check_argument("path", c, restricted=True) + check_argument("meta_file_train", c, restricted=True) + check_argument("meta_file_val", c, restricted=False) + check_argument("meta_file_attn_mask", c, restricted=False) + + +@dataclass +class BaseTrainingConfig(TrainerConfig): + """Base config to define the basic 🐸TTS training parameters that are shared + among all the models. It is based on ```Trainer.TrainingConfig```. + + Args: + model (str): + Name of the model that is used in the training. + + num_loader_workers (int): + Number of workers for training time dataloader. + + num_eval_loader_workers (int): + Number of workers for evaluation time dataloader. + """ + + model: str = None + # dataloading + num_loader_workers: int = 0 + num_eval_loader_workers: int = 0 + use_noise_augment: bool = False diff --git a/TTS/encoder/README.md b/TTS/encoder/README.md new file mode 100644 index 0000000..b38b200 --- /dev/null +++ b/TTS/encoder/README.md @@ -0,0 +1,18 @@ +### Speaker Encoder + +This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding. + +With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart. + +Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q). + +![](umap.png) + +Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page. + +To run the code, you need to follow the same flow as in TTS. + +- Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. +- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` +- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. +- Watch training on Tensorboard as in TTS diff --git a/TTS/encoder/__init__.py b/TTS/encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TTS/encoder/__pycache__/__init__.cpython-311.pyc b/TTS/encoder/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..ad49a98 Binary files /dev/null and b/TTS/encoder/__pycache__/__init__.cpython-311.pyc differ diff --git a/TTS/encoder/__pycache__/losses.cpython-311.pyc b/TTS/encoder/__pycache__/losses.cpython-311.pyc new file mode 100644 index 0000000..fce5534 Binary files /dev/null and b/TTS/encoder/__pycache__/losses.cpython-311.pyc differ diff --git a/TTS/encoder/configs/base_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py new file mode 100644 index 0000000..ebbaa04 --- /dev/null +++ b/TTS/encoder/configs/base_encoder_config.py @@ -0,0 +1,61 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import MISSING + +from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseEncoderConfig(BaseTrainingConfig): + """Defines parameters for a Generic Encoder model.""" + + model: str = None + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # model params + model_params: Dict = field( + default_factory=lambda: { + "model_name": "lstm", + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": True, + } + ) + + audio_augmentation: Dict = field(default_factory=lambda: {}) + + # training params + epochs: int = 10000 + loss: str = "angleproto" + grad_clip: float = 3.0 + lr: float = 0.0001 + optimizer: str = "radam" + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0}) + lr_decay: bool = False + warmup_steps: int = 4000 + + # logging params + tb_model_param_stats: bool = False + steps_plot_stats: int = 10 + save_step: int = 1000 + print_step: int = 20 + run_eval: bool = False + + # data loader + num_classes_in_batch: int = MISSING + num_utter_per_class: int = MISSING + eval_num_classes_in_batch: int = None + eval_num_utter_per_class: int = None + + num_loader_workers: int = MISSING + voice_len: float = 1.6 + + def check_values(self): + super().check_values() + c = asdict(self) + assert ( + c["model_params"]["input_dim"] == self.audio.num_mels + ), " [!] model input dimendion must be equal to melspectrogram dimension." diff --git a/TTS/encoder/configs/emotion_encoder_config.py b/TTS/encoder/configs/emotion_encoder_config.py new file mode 100644 index 0000000..5eda267 --- /dev/null +++ b/TTS/encoder/configs/emotion_encoder_config.py @@ -0,0 +1,12 @@ +from dataclasses import asdict, dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class EmotionEncoderConfig(BaseEncoderConfig): + """Defines parameters for Emotion Encoder model.""" + + model: str = "emotion_encoder" + map_classid_to_classname: dict = None + class_name_key: str = "emotion_name" diff --git a/TTS/encoder/configs/speaker_encoder_config.py b/TTS/encoder/configs/speaker_encoder_config.py new file mode 100644 index 0000000..6dceb00 --- /dev/null +++ b/TTS/encoder/configs/speaker_encoder_config.py @@ -0,0 +1,11 @@ +from dataclasses import asdict, dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class SpeakerEncoderConfig(BaseEncoderConfig): + """Defines parameters for Speaker Encoder model.""" + + model: str = "speaker_encoder" + class_name_key: str = "speaker_name" diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py new file mode 100644 index 0000000..582b1fe --- /dev/null +++ b/TTS/encoder/dataset.py @@ -0,0 +1,147 @@ +import random + +import torch +from torch.utils.data import Dataset + +from TTS.encoder.utils.generic_utils import AugmentWAV + + +class EncoderDataset(Dataset): + def __init__( + self, + config, + ap, + meta_data, + voice_len=1.6, + num_classes_in_batch=64, + num_utter_per_class=10, + verbose=False, + augmentation_config=None, + use_torch_spec=None, + ): + """ + Args: + ap (TTS.tts.utils.AudioProcessor): audio processor object. + meta_data (list): list of dataset instances. + seq_len (int): voice segment length in seconds. + verbose (bool): print diagnostic information. + """ + super().__init__() + self.config = config + self.items = meta_data + self.sample_rate = ap.sample_rate + self.seq_len = int(voice_len * self.sample_rate) + self.num_utter_per_class = num_utter_per_class + self.ap = ap + self.verbose = verbose + self.use_torch_spec = use_torch_spec + self.classes, self.items = self.__parse_items() + + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + # Data Augmentation + self.augmentator = None + self.gaussian_augmentation_config = None + if augmentation_config: + self.data_augmentation_p = augmentation_config["p"] + if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config): + self.augmentator = AugmentWAV(ap, augmentation_config) + + if "gaussian" in augmentation_config.keys(): + self.gaussian_augmentation_config = augmentation_config["gaussian"] + + if self.verbose: + print("\n > DataLoader initialization") + print(f" | > Classes per Batch: {num_classes_in_batch}") + print(f" | > Number of instances : {len(self.items)}") + print(f" | > Sequence length: {self.seq_len}") + print(f" | > Num Classes: {len(self.classes)}") + print(f" | > Classes: {self.classes}") + + def load_wav(self, filename): + audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) + return audio + + def __parse_items(self): + class_to_utters = {} + for item in self.items: + path_ = item["audio_file"] + class_name = item[self.config.class_name_key] + if class_name in class_to_utters.keys(): + class_to_utters[class_name].append(path_) + else: + class_to_utters[class_name] = [ + path_, + ] + + # skip classes with number of samples >= self.num_utter_per_class + class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class} + + classes = list(class_to_utters.keys()) + classes.sort() + + new_items = [] + for item in self.items: + path_ = item["audio_file"] + class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"] + # ignore filtered classes + if class_name not in classes: + continue + # ignore small audios + if self.load_wav(path_).shape[0] - self.seq_len <= 0: + continue + + new_items.append({"wav_file_path": path_, "class_name": class_name}) + + return classes, new_items + + def __len__(self): + return len(self.items) + + def get_num_classes(self): + return len(self.classes) + + def get_class_list(self): + return self.classes + + def set_classes(self, classes): + self.classes = classes + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + def get_map_classid_to_classname(self): + return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items()) + + def __getitem__(self, idx): + return self.items[idx] + + def collate_fn(self, batch): + # get the batch class_ids + labels = [] + feats = [] + for item in batch: + utter_path = item["wav_file_path"] + class_name = item["class_name"] + + # get classid + class_id = self.classname_to_classid[class_name] + # load wav file + wav = self.load_wav(utter_path) + offset = random.randint(0, wav.shape[0] - self.seq_len) + wav = wav[offset : offset + self.seq_len] + + if self.augmentator is not None and self.data_augmentation_p: + if random.random() < self.data_augmentation_p: + wav = self.augmentator.apply_one(wav) + + if not self.use_torch_spec: + mel = self.ap.melspectrogram(wav) + feats.append(torch.FloatTensor(mel)) + else: + feats.append(torch.FloatTensor(wav)) + + labels.append(class_id) + + feats = torch.stack(feats) + labels = torch.LongTensor(labels) + + return feats, labels diff --git a/TTS/encoder/losses.py b/TTS/encoder/losses.py new file mode 100644 index 0000000..5b5aa0f --- /dev/null +++ b/TTS/encoder/losses.py @@ -0,0 +1,226 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +# adapted from https://github.com/cvqluu/GE2E-Loss +class GE2ELoss(nn.Module): + def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): + """ + Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector (e.g. d-vector) + Args: + - init_w (float): defines the initial value of w in Equation (5) of [1] + - init_b (float): definies the initial value of b in Equation (5) of [1] + """ + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.loss_method = loss_method + + print(" > Initialized Generalized End-to-End loss") + + assert self.loss_method in ["softmax", "contrast"] + + if self.loss_method == "softmax": + self.embed_loss = self.embed_loss_softmax + if self.loss_method == "contrast": + self.embed_loss = self.embed_loss_contrast + + # pylint: disable=R0201 + def calc_new_centroids(self, dvecs, centroids, spkr, utt): + """ + Calculates the new centroids excluding the reference utterance + """ + excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) + excl = torch.mean(excl, 0) + new_centroids = [] + for i, centroid in enumerate(centroids): + if i == spkr: + new_centroids.append(excl) + else: + new_centroids.append(centroid) + return torch.stack(new_centroids) + + def calc_cosine_sim(self, dvecs, centroids): + """ + Make the cosine similarity matrix with dims (N,M,N) + """ + cos_sim_matrix = [] + for spkr_idx, speaker in enumerate(dvecs): + cs_row = [] + for utt_idx, utterance in enumerate(speaker): + new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) + # vector based cosine similarity for speed + cs_row.append( + torch.clamp( + torch.mm( + utterance.unsqueeze(1).transpose(0, 1), + new_centroids.transpose(0, 1), + ) + / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), + 1e-6, + ) + ) + cs_row = torch.cat(cs_row, dim=0) + cos_sim_matrix.append(cs_row) + return torch.stack(cos_sim_matrix) + + # pylint: disable=R0201 + def embed_loss_softmax(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by taking softmax + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + # pylint: disable=R0201 + def embed_loss_contrast(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) + excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) + L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids)) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + def forward(self, x, _label=None): + """ + Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + centroids = torch.mean(x, 1) + cos_sim_matrix = self.calc_cosine_sim(x, centroids) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = self.w * cos_sim_matrix + self.b + L = self.embed_loss(x, cos_sim_matrix) + return L.mean() + + +# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py +class AngleProtoLoss(nn.Module): + """ + Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector + Args: + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, init_w=10.0, init_b=-5.0): + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.criterion = torch.nn.CrossEntropyLoss() + + print(" > Initialized Angular Prototypical loss") + + def forward(self, x, _label=None): + """ + Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + out_anchor = torch.mean(x[:, 1:, :], 1) + out_positive = x[:, 0, :] + num_speakers = out_anchor.size()[0] + + cos_sim_matrix = F.cosine_similarity( + out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), + out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), + ) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = cos_sim_matrix * self.w + self.b + label = torch.arange(num_speakers).to(cos_sim_matrix.device) + L = self.criterion(cos_sim_matrix, label) + return L + + +class SoftmaxLoss(nn.Module): + """ + Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + """ + + def __init__(self, embedding_dim, n_speakers): + super().__init__() + + self.criterion = torch.nn.CrossEntropyLoss() + self.fc = nn.Linear(embedding_dim, n_speakers) + + print("Initialised Softmax Loss") + + def forward(self, x, label=None): + # reshape for compatibility + x = x.reshape(-1, x.size()[-1]) + label = label.reshape(-1) + + x = self.fc(x) + L = self.criterion(x, label) + + return L + + def inference(self, embedding): + x = self.fc(embedding) + activations = torch.nn.functional.softmax(x, dim=1).squeeze(0) + class_id = torch.argmax(activations) + return class_id + + +class SoftmaxAngleProtoLoss(nn.Module): + """ + Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): + super().__init__() + + self.softmax = SoftmaxLoss(embedding_dim, n_speakers) + self.angleproto = AngleProtoLoss(init_w, init_b) + + print("Initialised SoftmaxAnglePrototypical Loss") + + def forward(self, x, label=None): + """ + Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + Lp = self.angleproto(x) + + Ls = self.softmax(x, label) + + return Ls + Lp diff --git a/TTS/encoder/models/__pycache__/base_encoder.cpython-311.pyc b/TTS/encoder/models/__pycache__/base_encoder.cpython-311.pyc new file mode 100644 index 0000000..51ea66b Binary files /dev/null and b/TTS/encoder/models/__pycache__/base_encoder.cpython-311.pyc differ diff --git a/TTS/encoder/models/__pycache__/lstm.cpython-311.pyc b/TTS/encoder/models/__pycache__/lstm.cpython-311.pyc new file mode 100644 index 0000000..ad58d0e Binary files /dev/null and b/TTS/encoder/models/__pycache__/lstm.cpython-311.pyc differ diff --git a/TTS/encoder/models/__pycache__/resnet.cpython-311.pyc b/TTS/encoder/models/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000..b52f73e Binary files /dev/null and b/TTS/encoder/models/__pycache__/resnet.cpython-311.pyc differ diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py new file mode 100644 index 0000000..957ea3c --- /dev/null +++ b/TTS/encoder/models/base_encoder.py @@ -0,0 +1,161 @@ +import numpy as np +import torch +import torchaudio +from coqpit import Coqpit +from torch import nn + +from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss +from TTS.utils.generic_utils import set_init_dict +from TTS.utils.io import load_fsspec + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + +class BaseEncoder(nn.Module): + """Base `encoder` class. Every new `encoder` model must inherit this. + + It defines common `encoder` specific functions. + """ + + # pylint: disable=W0102 + def __init__(self): + super(BaseEncoder, self).__init__() + + def get_torch_mel_spectrogram_class(self, audio_config): + return torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + @torch.no_grad() + def inference(self, x, l2_norm=True): + return self.forward(x, l2_norm) + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + # map to the waveform size + if self.use_torch_spec: + num_frames = num_frames * self.audio_config["hop_length"] + + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch, l2_norm=l2_norm) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + return embeddings + + def get_criterion(self, c: Coqpit, num_classes=None): + if c.loss == "ge2e": + criterion = GE2ELoss(loss_method="softmax") + elif c.loss == "angleproto": + criterion = AngleProtoLoss() + elif c.loss == "softmaxproto": + criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes) + else: + raise Exception("The %s not is a loss supported" % c.loss) + return criterion + + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: str, + eval: bool = False, + use_cuda: bool = False, + criterion=None, + cache=False, + ): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + try: + self.load_state_dict(state["model"]) + print(" > Model fully restored. ") + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + print(" > Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"], c) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + print(" > Criterion load ignored because of:", error) + + # instance and load the criterion for the encoder classifier in inference time + if ( + eval + and criterion is None + and "criterion" in state + and getattr(config, "map_classid_to_classname", None) is not None + ): + criterion = self.get_criterion(config, len(config.map_classid_to_classname)) + criterion.load_state_dict(state["criterion"]) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion diff --git a/TTS/encoder/models/lstm.py b/TTS/encoder/models/lstm.py new file mode 100644 index 0000000..51852b5 --- /dev/null +++ b/TTS/encoder/models/lstm.py @@ -0,0 +1,99 @@ +import torch +from torch import nn + +from TTS.encoder.models.base_encoder import BaseEncoder + + +class LSTMWithProjection(nn.Module): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = nn.Linear(hidden_size, proj_size, bias=False) + + def forward(self, x): + self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + +class LSTMWithoutProjection(nn.Module): + def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): + super().__init__() + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) + self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) + self.relu = nn.ReLU() + + def forward(self, x): + _, (hidden, _) = self.lstm(x) + return self.relu(self.linear(hidden[-1])) + + +class LSTMSpeakerEncoder(BaseEncoder): + def __init__( + self, + input_dim, + proj_dim=256, + lstm_dim=768, + num_lstm_layers=3, + use_lstm_with_projection=True, + use_torch_spec=False, + audio_config=None, + ): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = nn.Sequential(*layers) + else: + self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0.0) + elif "weight" in name: + nn.init.xavier_normal_(param) + + def forward(self, x, l2_norm=True): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.use_torch_spec: + x.squeeze_(1) + x = self.torch_spec(x) + x = self.instancenorm(x).transpose(1, 2) + d = self.layers(x) + if self.use_lstm_with_projection: + d = d[:, -1] + if l2_norm: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py new file mode 100644 index 0000000..5eafcd6 --- /dev/null +++ b/TTS/encoder/models/resnet.py @@ -0,0 +1,198 @@ +import torch +from torch import nn + +# from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.encoder.models.base_encoder import BaseEncoder + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class ResNetSpeakerEncoder(BaseEncoder): + """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153 + Adapted from: https://github.com/clovaai/voxceleb_trainer + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x diff --git a/TTS/encoder/requirements.txt b/TTS/encoder/requirements.txt new file mode 100644 index 0000000..a486cc4 --- /dev/null +++ b/TTS/encoder/requirements.txt @@ -0,0 +1,2 @@ +umap-learn +numpy>=1.17.0 diff --git a/TTS/encoder/utils/__init__.py b/TTS/encoder/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TTS/encoder/utils/__pycache__/__init__.cpython-311.pyc b/TTS/encoder/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..ae87314 Binary files /dev/null and b/TTS/encoder/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/TTS/encoder/utils/__pycache__/generic_utils.cpython-311.pyc b/TTS/encoder/utils/__pycache__/generic_utils.cpython-311.pyc new file mode 100644 index 0000000..68db90c Binary files /dev/null and b/TTS/encoder/utils/__pycache__/generic_utils.cpython-311.pyc differ diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py new file mode 100644 index 0000000..236d6fe --- /dev/null +++ b/TTS/encoder/utils/generic_utils.py @@ -0,0 +1,136 @@ +import glob +import os +import random + +import numpy as np +from scipy import signal + +from TTS.encoder.models.lstm import LSTMSpeakerEncoder +from TTS.encoder.models.resnet import ResNetSpeakerEncoder + + +class AugmentWAV(object): + def __init__(self, ap, augmentation_config): + self.ap = ap + self.use_additive_noise = False + + if "additive" in augmentation_config.keys(): + self.additive_noise_config = augmentation_config["additive"] + additive_path = self.additive_noise_config["sounds_path"] + if additive_path: + self.use_additive_noise = True + # get noise types + self.additive_noise_types = [] + for key in self.additive_noise_config.keys(): + if isinstance(self.additive_noise_config[key], dict): + self.additive_noise_types.append(key) + + additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True) + + self.noise_list = {} + + for wav_file in additive_files: + noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0] + # ignore not listed directories + if noise_dir not in self.additive_noise_types: + continue + if not noise_dir in self.noise_list: + self.noise_list[noise_dir] = [] + self.noise_list[noise_dir].append(wav_file) + + print( + f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" + ) + + self.use_rir = False + + if "rir" in augmentation_config.keys(): + self.rir_config = augmentation_config["rir"] + if self.rir_config["rir_path"]: + self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) + self.use_rir = True + + print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") + + self.create_augmentation_global_list() + + def create_augmentation_global_list(self): + if self.use_additive_noise: + self.global_noise_list = self.additive_noise_types + else: + self.global_noise_list = [] + if self.use_rir: + self.global_noise_list.append("RIR_AUG") + + def additive_noise(self, noise_type, audio): + clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4) + + noise_list = random.sample( + self.noise_list[noise_type], + random.randint( + self.additive_noise_config[noise_type]["min_num_noises"], + self.additive_noise_config[noise_type]["max_num_noises"], + ), + ) + + audio_len = audio.shape[0] + noises_wav = None + for noise in noise_list: + noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len] + + if noiseaudio.shape[0] < audio_len: + continue + + noise_snr = random.uniform( + self.additive_noise_config[noise_type]["min_snr_in_db"], + self.additive_noise_config[noise_type]["max_num_noises"], + ) + noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4) + noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio + + if noises_wav is None: + noises_wav = noise_wav + else: + noises_wav += noise_wav + + # if all possible files is less than audio, choose other files + if noises_wav is None: + return self.additive_noise(noise_type, audio) + + return audio + noises_wav + + def reverberate(self, audio): + audio_len = audio.shape[0] + + rir_file = random.choice(self.rir_files) + rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) + rir = rir / np.sqrt(np.sum(rir**2)) + return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] + + def apply_one(self, audio): + noise_type = random.choice(self.global_noise_list) + if noise_type == "RIR_AUG": + return self.reverberate(audio) + + return self.additive_noise(noise_type, audio) + + +def setup_encoder_model(config: "Coqpit"): + if config.model_params["model_name"].lower() == "lstm": + model = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + elif config.model_params["model_name"].lower() == "resnet": + model = ResNetSpeakerEncoder( + input_dim=config.model_params["input_dim"], + proj_dim=config.model_params["proj_dim"], + log_input=config.model_params.get("log_input", False), + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + return model diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py new file mode 100644 index 0000000..b93baf9 --- /dev/null +++ b/TTS/encoder/utils/prepare_voxceleb.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Only support eager mode and TF>=2.0.0 +# pylint: disable=no-member, invalid-name, relative-beyond-top-level +# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes +""" voxceleb 1 & 2 """ + +import hashlib +import os +import subprocess +import sys +import zipfile + +import pandas +import soundfile as sf +from absl import logging + +SUBSETS = { + "vox1_dev_wav": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "vox1_test_wav": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], + "vox2_dev_aac": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", + ], + "vox2_test_aac": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], +} + +MD5SUM = { + "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", + "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", + "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", + "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", +} + +USER = {"user": "", "password": ""} + +speaker_id_dict = {} + + +def download_and_extract(directory, subset, urls): + """Download and extract the given split of dataset. + + Args: + directory: the directory where to put the downloaded data. + subset: subset name of the corpus. + urls: the list of urls to download the data file. + """ + os.makedirs(directory, exist_ok=True) + + try: + for url in urls: + zip_filepath = os.path.join(directory, url.split("/")[-1]) + if os.path.exists(zip_filepath): + continue + logging.info("Downloading %s to %s" % (url, zip_filepath)) + subprocess.call( + "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), + shell=True, + ) + + statinfo = os.stat(zip_filepath) + logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + + # concatenate all parts into zip files + if ".zip" not in zip_filepath: + zip_filepath = "_".join(zip_filepath.split("_")[:-1]) + subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) + zip_filepath += ".zip" + extract_path = zip_filepath.strip(".zip") + + # check zip file md5sum + with open(zip_filepath, "rb") as f_zip: + md5 = hashlib.md5(f_zip.read()).hexdigest() + if md5 != MD5SUM[subset]: + raise ValueError("md5sum of %s mismatch" % zip_filepath) + + with zipfile.ZipFile(zip_filepath, "r") as zfile: + zfile.extractall(directory) + extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) + subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) + finally: + # os.remove(zip_filepath) + pass + + +def exec_cmd(cmd): + """Run a command in a subprocess. + Args: + cmd: command line to be executed. + Return: + int, the return code. + """ + try: + retcode = subprocess.call(cmd, shell=True) + if retcode < 0: + logging.info(f"Child was terminated by signal {retcode}") + except OSError as e: + logging.info(f"Execution failed: {e}") + retcode = -999 + return retcode + + +def decode_aac_with_ffmpeg(aac_file, wav_file): + """Decode a given AAC file into WAV using ffmpeg. + Args: + aac_file: file path to input AAC file. + wav_file: file path to output WAV file. + Return: + bool, True if success. + """ + cmd = f"ffmpeg -i {aac_file} {wav_file}" + logging.info(f"Decoding aac file using command line: {cmd}") + ret = exec_cmd(cmd) + if ret != 0: + logging.error(f"Failed to decode aac file with retcode {ret}") + logging.error("Please check your ffmpeg installation.") + return False + return True + + +def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): + """Optionally convert AAC to WAV and make speaker labels. + Args: + input_dir: the directory which holds the input dataset. + subset: the name of the specified subset. e.g. vox1_dev_wav + output_dir: the directory to place the newly generated csv files. + output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv + """ + + logging.info("Preprocessing audio and label for subset %s" % subset) + source_dir = os.path.join(input_dir, subset) + + files = [] + # Convert all AAC file into WAV format. At the same time, generate the csv + for root, _, filenames in os.walk(source_dir): + for filename in filenames: + name, ext = os.path.splitext(filename) + if ext.lower() == ".wav": + _, ext2 = os.path.splitext(name) + if ext2: + continue + wav_file = os.path.join(root, filename) + elif ext.lower() == ".m4a": + # Convert AAC to WAV. + aac_file = os.path.join(root, filename) + wav_file = aac_file + ".wav" + if not os.path.exists(wav_file): + if not decode_aac_with_ffmpeg(aac_file, wav_file): + raise RuntimeError("Audio decoding failed.") + else: + continue + speaker_name = root.split(os.path.sep)[-2] + if speaker_name not in speaker_id_dict: + num = len(speaker_id_dict) + speaker_id_dict[speaker_name] = num + # wav_filesize = os.path.getsize(wav_file) + wav_length = len(sf.read(wav_file)[0]) + files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) + + # Write to CSV file which contains four columns: + # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". + csv_file_path = os.path.join(output_dir, output_file) + df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + df.to_csv(csv_file_path, index=False, sep="\t") + logging.info("Successfully generated csv file {}".format(csv_file_path)) + + +def processor(directory, subset, force_process): + """download and process""" + urls = SUBSETS + if subset not in urls: + raise ValueError(subset, "is not in voxceleb") + + subset_csv = os.path.join(directory, subset + ".csv") + if not force_process and os.path.exists(subset_csv): + return subset_csv + + logging.info("Downloading and process the voxceleb in %s", directory) + logging.info("Preparing subset %s", subset) + download_and_extract(directory, subset, urls[subset]) + convert_audio_and_make_label(directory, subset, directory, subset + ".csv") + logging.info("Finished downloading and processing") + return subset_csv + + +if __name__ == "__main__": + logging.set_verbosity(logging.INFO) + if len(sys.argv) != 4: + print("Usage: python prepare_data.py save_directory user password") + sys.exit() + + DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3] + for SUBSET in SUBSETS: + processor(DIR, SUBSET, False) diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py new file mode 100644 index 0000000..ff8f271 --- /dev/null +++ b/TTS/encoder/utils/training.py @@ -0,0 +1,99 @@ +import os +from dataclasses import dataclass, field + +from coqpit import Coqpit +from trainer import TrainerArgs, get_last_checkpoint +from trainer.io import copy_model_files +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger + +from TTS.config import load_config, register_config +from TTS.tts.utils.text.characters import parse_symbols +from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch + + +@dataclass +class TrainArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def getarguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def process_args(args, config=None): + """Process parsed comand line arguments and initialize the config if not provided. + Args: + args (argparse.Namespace or dict like): Parsed input arguments. + config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. + Returns: + c (TTS.utils.io.AttrDict): Config paramaters. + out_path (str): Path to save models and logging. + audio_path (str): Path to save generated test audios. + c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does + logging to the console. + dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging + TODO: + - Interactive config definition. + """ + if isinstance(args, tuple): + args, coqpit_overrides = args + if args.continue_path: + # continue a previous training from its output folder + experiment_path = args.continue_path + args.config_path = os.path.join(args.continue_path, "config.json") + args.restore_path, best_model = get_last_checkpoint(args.continue_path) + if not args.best_path: + args.best_path = best_model + # init config if not already defined + if config is None: + if args.config_path: + # init from a file + config = load_config(args.config_path) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(coqpit_overrides) + config = register_config(config_base.model)() + # override values from command-line args + config.parse_known_args(coqpit_overrides, relaxed_parser=True) + experiment_path = args.continue_path + if not experiment_path: + experiment_path = get_experiment_folder_path(config.output_path, config.run_name) + audio_path = os.path.join(experiment_path, "test_audios") + config.output_log_path = experiment_path + # setup rank 0 process in distributed training + dashboard_logger = None + if args.rank == 0: + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + # if model characters are not set in the config file + # save the default set to the config file for future + # compatibility. + if config.has("characters") and config.characters is None: + used_characters = parse_symbols() + new_fields["characters"] = used_characters + copy_model_files(config, experiment_path, new_fields) + dashboard_logger = logger_factory(config, experiment_path) + c_logger = ConsoleLogger() + return config, experiment_path, audio_path, c_logger, dashboard_logger + + +def init_arguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def init_training(config: Coqpit = None): + """Initialization of a training run.""" + parser = init_arguments() + args = parser.parse_known_args() + config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config) + return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger diff --git a/TTS/encoder/utils/visual.py b/TTS/encoder/utils/visual.py new file mode 100644 index 0000000..6575b86 --- /dev/null +++ b/TTS/encoder/utils/visual.py @@ -0,0 +1,50 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import umap + +matplotlib.use("Agg") + + +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=float, + ) + / 255 +) + + +def plot_embeddings(embeddings, num_classes_in_batch): + num_utter_per_class = embeddings.shape[0] // num_classes_in_batch + + # if necessary get just the first 10 classes + if num_classes_in_batch > 10: + num_classes_in_batch = 10 + embeddings = embeddings[: num_classes_in_batch * num_utter_per_class] + + model = umap.UMAP() + projection = model.fit_transform(embeddings) + ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class) + colors = [colormap[i] for i in ground_truth] + fig, ax = plt.subplots(figsize=(16, 10)) + _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection") + plt.tight_layout() + plt.savefig("umap") + return fig diff --git a/TTS/model.py b/TTS/model.py new file mode 100644 index 0000000..ae6be7b --- /dev/null +++ b/TTS/model.py @@ -0,0 +1,59 @@ +from abc import abstractmethod +from typing import Dict + +import torch +from coqpit import Coqpit +from trainer import TrainerModel + +# pylint: skip-file + + +class BaseTrainerModel(TrainerModel): + """BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS. + + Every new 🐸TTS model must inherit it. + """ + + @staticmethod + @abstractmethod + def init_from_config(config: Coqpit): + """Init the model and all its attributes from the given config. + + Override this depending on your model. + """ + ... + + @abstractmethod + def inference(self, input: torch.Tensor, aux_input={}) -> Dict: + """Forward pass for inference. + + It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs``` + is considered to be the main output and you can add any other auxiliary outputs as you want. + + We don't use `*kwargs` since it is problematic with the TorchScript API. + + Args: + input (torch.Tensor): [description] + aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc. + + Returns: + Dict: [description] + """ + outputs_dict = {"model_outputs": None} + ... + return outputs_dict + + @abstractmethod + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ) -> None: + """Load a model checkpoint gile and get ready for training or inference. + + Args: + config (Coqpit): Model configuration. + checkpoint_path (str): Path to the model checkpoint file. + eval (bool, optional): If true, init model for inference else for training. Defaults to False. + strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. + cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. + """ + ... diff --git a/TTS/utils/__init__.py b/TTS/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TTS/utils/__pycache__/__init__.cpython-311.pyc b/TTS/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..4abc14c Binary files /dev/null and b/TTS/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/TTS/utils/__pycache__/generic_utils.cpython-311.pyc b/TTS/utils/__pycache__/generic_utils.cpython-311.pyc new file mode 100644 index 0000000..6b8a83c Binary files /dev/null and b/TTS/utils/__pycache__/generic_utils.cpython-311.pyc differ diff --git a/TTS/utils/__pycache__/io.cpython-311.pyc b/TTS/utils/__pycache__/io.cpython-311.pyc new file mode 100644 index 0000000..decdb56 Binary files /dev/null and b/TTS/utils/__pycache__/io.cpython-311.pyc differ diff --git a/TTS/utils/__pycache__/manage.cpython-311.pyc b/TTS/utils/__pycache__/manage.cpython-311.pyc new file mode 100644 index 0000000..4b15463 Binary files /dev/null and b/TTS/utils/__pycache__/manage.cpython-311.pyc differ diff --git a/TTS/utils/__pycache__/samplers.cpython-311.pyc b/TTS/utils/__pycache__/samplers.cpython-311.pyc new file mode 100644 index 0000000..d93fa4b Binary files /dev/null and b/TTS/utils/__pycache__/samplers.cpython-311.pyc differ diff --git a/TTS/utils/__pycache__/synthesizer.cpython-311.pyc b/TTS/utils/__pycache__/synthesizer.cpython-311.pyc new file mode 100644 index 0000000..9684e8e Binary files /dev/null and b/TTS/utils/__pycache__/synthesizer.cpython-311.pyc differ diff --git a/TTS/utils/audio/__init__.py b/TTS/utils/audio/__init__.py new file mode 100644 index 0000000..f18f221 --- /dev/null +++ b/TTS/utils/audio/__init__.py @@ -0,0 +1 @@ +from TTS.utils.audio.processor import AudioProcessor diff --git a/TTS/utils/audio/__pycache__/__init__.cpython-311.pyc b/TTS/utils/audio/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..1060a43 Binary files /dev/null and b/TTS/utils/audio/__pycache__/__init__.cpython-311.pyc differ diff --git a/TTS/utils/audio/__pycache__/numpy_transforms.cpython-311.pyc b/TTS/utils/audio/__pycache__/numpy_transforms.cpython-311.pyc new file mode 100644 index 0000000..1856323 Binary files /dev/null and b/TTS/utils/audio/__pycache__/numpy_transforms.cpython-311.pyc differ diff --git a/TTS/utils/audio/__pycache__/processor.cpython-311.pyc b/TTS/utils/audio/__pycache__/processor.cpython-311.pyc new file mode 100644 index 0000000..faa74a5 Binary files /dev/null and b/TTS/utils/audio/__pycache__/processor.cpython-311.pyc differ diff --git a/TTS/utils/audio/__pycache__/torch_transforms.cpython-311.pyc b/TTS/utils/audio/__pycache__/torch_transforms.cpython-311.pyc new file mode 100644 index 0000000..0904835 Binary files /dev/null and b/TTS/utils/audio/__pycache__/torch_transforms.cpython-311.pyc differ diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py new file mode 100644 index 0000000..af88569 --- /dev/null +++ b/TTS/utils/audio/numpy_transforms.py @@ -0,0 +1,485 @@ +from io import BytesIO +from typing import Tuple + +import librosa +import numpy as np +import scipy +import soundfile as sf +from librosa import magphase, pyin + +# For using kwargs +# pylint: disable=unused-argument + + +def build_mel_basis( + *, + sample_rate: int = None, + fft_size: int = None, + num_mels: int = None, + mel_fmax: int = None, + mel_fmin: int = None, + **kwargs, +) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. + """ + if mel_fmax is not None: + assert mel_fmax <= sample_rate // 2 + assert mel_fmax - mel_fmin > 0 + return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax) + + +def millisec_to_length( + *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs +) -> Tuple[int, int]: + """Compute hop and window length from milliseconds. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = frame_length_ms / frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + win_length = int(frame_length_ms / 1000.0 * sample_rate) + hop_length = int(win_length / float(factor)) + return win_length, hop_length + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) + + +def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Decibels spectrogram. + """ + assert (x < 0).sum() == 0, " [!] Input values must be non-negative." + return gain * _log(np.maximum(1e-8, x), base) + + +# pylint: disable=no-self-use +def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / gain, base) + + +def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -coef], [1], x) + + +def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray: + """Reverse pre-emphasis.""" + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -coef], x) + + +def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + spec (np.ndarray): Normalized full scale linear spectrogram. + + Shapes: + - spec: :math:`[C, T]` + + Returns: + np.ndarray: Normalized melspectrogram. + """ + return np.dot(mel_basis, spec) + + +def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + assert (mel < 0).sum() == 0, " [!] Input values must be non-negative." + inv_mel_basis = np.linalg.pinv(mel_basis) + return np.maximum(1e-10, np.dot(inv_mel_basis, mel)) + + +def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + wav (np.ndarray): Waveform. Shape :math:`[T_wav,]` + + Returns: + np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length` + """ + D = stft(y=wav, **kwargs) + S = np.abs(D) + return S.astype(np.float32) + + +def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + D = stft(y=wav, **kwargs) + S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs) + return S.astype(np.float32) + + +def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = spec.copy() + return griffin_lim(spec=S**power, **kwargs) + + +def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + S = mel.copy() + S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear + return griffin_lim(spec=S**power, **kwargs) + + +### STFT and ISTFT ### +def stft( + *, + y: np.ndarray = None, + fft_size: int = None, + hop_length: int = None, + win_length: int = None, + pad_mode: str = "reflect", + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa STFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.stft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.stft( + y=y, + n_fft=fft_size, + hop_length=hop_length, + win_length=win_length, + pad_mode=pad_mode, + window=window, + center=center, + ) + + +def istft( + *, + y: np.ndarray = None, + hop_length: int = None, + win_length: int = None, + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa iSTFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.istft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window) + + +def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray: + angles = np.exp(2j * np.pi * np.random.rand(*spec.shape)) + S_complex = np.abs(spec).astype(complex) + y = istft(y=S_complex * angles, **kwargs) + if not np.isfinite(y).all(): + print(" [!] Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(num_iter): + angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) + y = istft(y=S_complex * angles, **kwargs) + return y + + +def compute_stft_paddings( + *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs +) -> Tuple[int, int]: + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0] + if not pad_two_sides: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + +def compute_f0( + *, + x: np.ndarray = None, + pitch_fmax: float = None, + pitch_fmin: float = None, + hop_length: int = None, + win_length: int = None, + sample_rate: int = None, + stft_pad_mode: str = "reflect", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. Shape :math:`[T_wav,]` + pitch_fmax (float): Pitch max value. + pitch_fmin (float): Pitch min value. + hop_length (int): Number of frames between STFT columns. + win_length (int): STFT window length. + sample_rate (int): Audio sampling rate. + stft_pad_mode (str): Padding mode for STFT. + center (bool): Centered padding. + + Returns: + np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length` + + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> pitch = ap.compute_f0(wav) + """ + assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`." + assert pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`." + + f0, voiced_mask, _ = pyin( + y=x.astype(np.double), + fmin=pitch_fmin, + fmax=pitch_fmax, + sr=sample_rate, + frame_length=win_length, + win_length=win_length // 2, + hop_length=hop_length, + pad_mode=stft_pad_mode, + center=center, + n_thresholds=100, + beta_parameters=(2, 18), + boltzmann_parameter=2, + resolution=0.1, + max_transition_rate=35.92, + switch_prob=0.01, + no_trough_prob=0.01, + ) + f0[~voiced_mask] = 0.0 + + return f0 + + +def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray: + """Compute energy of a waveform using the same parameters used for computing melspectrogram. + Args: + x (np.ndarray): Waveform. Shape :math:`[T_wav,]` + Returns: + np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length` + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig() + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> energy = ap.compute_energy(wav) + """ + x = stft(y=y, **kwargs) + mag, _ = magphase(x) + energy = np.sqrt(np.sum(mag**2, axis=0)) + return energy + + +### Audio Processing ### +def find_endpoint( + *, + wav: np.ndarray = None, + trim_db: float = -40, + sample_rate: int = None, + min_silence_sec=0.8, + gain: float = None, + base: int = None, + **kwargs, +) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + gian (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None. + base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10. + + Returns: + int: Last point without silence. + """ + window_length = int(sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = db_to_amp(x=-trim_db, gain=gain, base=base) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + +def trim_silence( + *, + wav: np.ndarray = None, + sample_rate: int = None, + trim_db: float = None, + win_length: int = None, + hop_length: int = None, + **kwargs, +) -> np.ndarray: + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0] + + +def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + coef (float): Coefficient to rescale the maximum value. Defaults to 0.95. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * coef + + +def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray: + r = 10 ** (db_level / 20) + a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) + return wav * a + + +def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + db_level (float): Target dB level in RMS. Defaults to -27.0. + + Returns: + np.ndarray: RMS normalized waveform. + """ + assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" + wav = rms_norm(wav=x, db_level=db_level) + return wav + + +def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False. + + Returns: + np.ndarray: Loaded waveform. + """ + if resample: + # loading with resampling. It is significantly slower. + x, _ = librosa.load(filename, sr=sample_rate) + else: + # SF is faster than librosa for loading files + x, _ = sf.read(filename) + return x + + +def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None: + """Save float waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform with float values in range [-1, 1] to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. + """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + + wav_norm = wav_norm.astype(np.int16) + if pipe_out: + wav_buffer = BytesIO() + scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm) + wav_buffer.seek(0) + pipe_out.buffer.write(wav_buffer.read()) + scipy.io.wavfile.write(path, sample_rate, wav_norm) + + +def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray: + mu = 2**mulaw_qc - 1 + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + +def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray: + """Recovers waveform from quantized values.""" + mu = 2**mulaw_qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + +def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray: + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) + + +def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + quantize_bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. + """ + return (x + 1.0) * (2**quantize_bits - 1) / 2 + + +def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray: + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2**quantize_bits - 1) - 1 diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py new file mode 100644 index 0000000..c53bad5 --- /dev/null +++ b/TTS/utils/audio/processor.py @@ -0,0 +1,633 @@ +from io import BytesIO +from typing import Dict, Tuple + +import librosa +import numpy as np +import scipy.io.wavfile +import scipy.signal + +from TTS.tts.utils.helpers import StandardScaler +from TTS.utils.audio.numpy_transforms import ( + amp_to_db, + build_mel_basis, + compute_f0, + db_to_amp, + deemphasis, + find_endpoint, + griffin_lim, + load_wav, + mel_to_spec, + millisec_to_length, + preemphasis, + rms_volume_norm, + spec_to_mel, + stft, + trim_silence, + volume_norm, +) + +# pylint: disable=too-many-public-methods + + +class AudioProcessor(object): + """Audio Processor for TTS. + + Note: + All the class arguments are set to default values to enable a flexible initialization + of the class with the model config. They are not meaningful for all the arguments. + + Args: + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + resample (bool, optional): + enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. + + num_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + log_func (int, optional): + log exponent used for converting spectrogram aplitude to DB. + + min_level_db (int, optional): + minimum db threshold for the computed melspectrograms. Defaults to None. + + frame_shift_ms (int, optional): + milliseconds of frames between STFT columns. Defaults to None. + + frame_length_ms (int, optional): + milliseconds of STFT window length. Defaults to None. + + hop_length (int, optional): + number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. + + win_length (int, optional): + STFT window length. Used if ```frame_length_ms``` is None. Defaults to None. + + ref_level_db (int, optional): + reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. + + fft_size (int, optional): + FFT window size for STFT. Defaults to 1024. + + power (int, optional): + Exponent value applied to the spectrogram before GriffinLim. Defaults to None. + + preemphasis (float, optional): + Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. + + signal_norm (bool, optional): + enable/disable signal normalization. Defaults to None. + + symmetric_norm (bool, optional): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. + + max_norm (float, optional): + ```k``` defining the normalization range. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + pitch_fmin (int, optional): + minimum filter frequency for computing pitch. Defaults to None. + + pitch_fmax (int, optional): + maximum filter frequency for computing pitch. Defaults to None. + + spec_gain (int, optional): + gain applied when converting amplitude to DB. Defaults to 20. + + stft_pad_mode (str, optional): + Padding mode for STFT. Defaults to 'reflect'. + + clip_norm (bool, optional): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + griffin_lim_iters (int, optional): + Number of GriffinLim iterations. Defaults to None. + + do_trim_silence (bool, optional): + enable/disable silence trimming when loading the audio signal. Defaults to False. + + trim_db (int, optional): + DB threshold used for silence trimming. Defaults to 60. + + do_sound_norm (bool, optional): + enable/disable signal normalization. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + do_rms_norm (bool, optional): + enable/disable RMS volume normalization when loading an audio file. Defaults to False. + + db_level (int, optional): + dB level used for rms normalization. The range is -99 to 0. Defaults to None. + + stats_path (str, optional): + Path to the computed stats file. Defaults to None. + + verbose (bool, optional): + enable/disable logging. Defaults to True. + + """ + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + pitch_fmax=None, + pitch_fmin=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + do_amp_to_db_linear=True, + do_amp_to_db_mel=True, + do_rms_norm=False, + db_level=None, + stats_path=None, + verbose=True, + **_, + ): + # setup class attributed + self.sample_rate = sample_rate + self.resample = resample + self.num_mels = num_mels + self.log_func = log_func + self.min_level_db = min_level_db or 0 + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.ref_level_db = ref_level_db + self.fft_size = fft_size + self.power = power + self.preemphasis = preemphasis + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + self.mel_fmin = mel_fmin or 0 + self.mel_fmax = mel_fmax + self.pitch_fmin = pitch_fmin + self.pitch_fmax = pitch_fmax + self.spec_gain = float(spec_gain) + self.stft_pad_mode = stft_pad_mode + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + self.trim_db = trim_db + self.do_sound_norm = do_sound_norm + self.do_amp_to_db_linear = do_amp_to_db_linear + self.do_amp_to_db_mel = do_amp_to_db_mel + self.do_rms_norm = do_rms_norm + self.db_level = db_level + self.stats_path = stats_path + # setup exp_func for db to amp conversion + if log_func == "np.log": + self.base = np.e + elif log_func == "np.log10": + self.base = 10 + else: + raise ValueError(" [!] unknown `log_func` value.") + # setup stft parameters + if hop_length is None: + # compute stft parameters from given time values + self.win_length, self.hop_length = millisec_to_length( + frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate + ) + else: + # use stft parameters from config file + self.hop_length = hop_length + self.win_length = win_length + assert min_level_db != 0.0, " [!] min_level_db is 0" + assert ( + self.win_length <= self.fft_size + ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" + members = vars(self) + if verbose: + print(" > Setting up Audio Processor...") + for key, value in members.items(): + print(" | > {}:{}".format(key, value)) + # create spectrogram utils + self.mel_basis = build_mel_basis( + sample_rate=self.sample_rate, + fft_size=self.fft_size, + num_mels=self.num_mels, + mel_fmax=self.mel_fmax, + mel_fmin=self.mel_fmin, + ) + # setup scaler + if stats_path and signal_norm: + mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) + self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) + self.signal_norm = True + self.max_norm = None + self.clip_norm = None + self.symmetric_norm = None + + @staticmethod + def init_from_config(config: "Coqpit", verbose=True): + if "audio" in config: + return AudioProcessor(verbose=verbose, **config.audio) + return AudioProcessor(verbose=verbose, **config) + + ### normalization ### + def normalize(self, S: np.ndarray) -> np.ndarray: + """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` + + Args: + S (np.ndarray): Spectrogram to normalize. + + Raises: + RuntimeError: Mean and variance is computed from incompatible parameters. + + Returns: + np.ndarray: Normalized spectrogram. + """ + # pylint: disable=no-else-return + S = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S.shape[0] == self.num_mels: + return self.mel_scaler.transform(S.T).T + elif S.shape[0] == self.fft_size / 2: + return self.linear_scaler.transform(S.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + # range normalization + S -= self.ref_level_db # discard certain range of DB assuming it is air noise + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def denormalize(self, S: np.ndarray) -> np.ndarray: + """Denormalize spectrogram values. + + Args: + S (np.ndarray): Spectrogram to denormalize. + + Raises: + RuntimeError: Mean and variance are incompatible. + + Returns: + np.ndarray: Denormalized spectrogram. + """ + # pylint: disable=no-else-return + S_denorm = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S_denorm.shape[0] == self.num_mels: + return self.mel_scaler.inverse_transform(S_denorm.T).T + elif S_denorm.shape[0] == self.fft_size / 2: + return self.linear_scaler.inverse_transform(S_denorm.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db + return S_denorm + self.ref_level_db + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db + return S_denorm + self.ref_level_db + else: + return S_denorm + + ### Mean-STD scaling ### + def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: + """Loading mean and variance statistics from a `npy` file. + + Args: + stats_path (str): Path to the `npy` file containing + + Returns: + Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to + compute them. + """ + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] + # check all audio parameters used for computing stats + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] + for key in stats_config.keys(): + if key in skip_parameters: + continue + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + return mel_mean, mel_std, linear_mean, linear_std, stats_config + + # pylint: disable=attribute-defined-outside-init + def setup_scaler( + self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray + ) -> None: + """Initialize scaler objects used in mean-std normalization. + + Args: + mel_mean (np.ndarray): Mean for melspectrograms. + mel_std (np.ndarray): STD for melspectrograms. + linear_mean (np.ndarray): Mean for full scale spectrograms. + linear_std (np.ndarray): STD for full scale spectrograms. + """ + self.mel_scaler = StandardScaler() + self.mel_scaler.set_stats(mel_mean, mel_std) + self.linear_scaler = StandardScaler() + self.linear_scaler.set_stats(linear_mean, linear_std) + + ### Preemphasis ### + def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + return preemphasis(x=x, coef=self.preemphasis) + + def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Reverse pre-emphasis.""" + return deemphasis(x=x, coef=self.preemphasis) + + ### SPECTROGRAMs ### + def spectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + y (np.ndarray): Waveform. + + Returns: + np.ndarray: Spectrogram. + """ + if self.preemphasis != 0: + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + if self.do_amp_to_db_linear: + S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base) + else: + S = np.abs(D) + return self.normalize(S).astype(np.float32) + + def melspectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + if self.preemphasis != 0: + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis) + if self.do_amp_to_db_mel: + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + + return self.normalize(S).astype(np.float32) + + def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = self.denormalize(spectrogram) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + # Reconstruct phase + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W + + def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + D = self.denormalize(mel_spectrogram) + S = db_to_amp(x=D, gain=self.spec_gain, base=self.base) + S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W + + def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + linear_spec (np.ndarray): Normalized full scale linear spectrogram. + + Returns: + np.ndarray: Normalized melspectrogram. + """ + S = self.denormalize(linear_spec) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis) + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + mel = self.normalize(S) + return mel + + def _griffin_lim(self, S): + return griffin_lim( + spec=S, + num_iter=self.griffin_lim_iters, + hop_length=self.hop_length, + win_length=self.win_length, + fft_size=self.fft_size, + pad_mode=self.stft_pad_mode, + ) + + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. + + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> pitch = ap.compute_f0(wav) + """ + # align F0 length to the spectrogram length + if len(x) % self.hop_length == 0: + x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode) + + f0 = compute_f0( + x=x, + pitch_fmax=self.pitch_fmax, + pitch_fmin=self.pitch_fmin, + hop_length=self.hop_length, + win_length=self.win_length, + sample_rate=self.sample_rate, + stft_pad_mode=self.stft_pad_mode, + center=True, + ) + + return f0 + + ### Audio Processing ### + def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + + Returns: + int: Last point without silence. + """ + return find_endpoint( + wav=wav, + trim_db=self.trim_db, + sample_rate=self.sample_rate, + min_silence_sec=min_silence_sec, + gain=self.spec_gain, + base=self.base, + ) + + def trim_silence(self, wav): + """Trim silent parts with a threshold and 0.01 sec margin""" + return trim_silence( + wav=wav, + sample_rate=self.sample_rate, + trim_db=self.trim_db, + win_length=self.win_length, + hop_length=self.hop_length, + ) + + @staticmethod + def sound_norm(x: np.ndarray) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return volume_norm(x=x) + + def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: RMS normalized waveform. + """ + if db_level is None: + db_level = self.db_level + return rms_volume_norm(x=x, db_level=db_level) + + ### save and load ### + def load_wav(self, filename: str, sr: int = None) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + + Returns: + np.ndarray: Loaded waveform. + """ + if sr is not None: + x = load_wav(filename=filename, sample_rate=sr, resample=True) + else: + x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample) + if self.do_trim_silence: + try: + x = self.trim_silence(x) + except ValueError: + print(f" [!] File cannot be trimmed for silence - {filename}") + if self.do_sound_norm: + x = self.sound_norm(x) + if self.do_rms_norm: + x = self.rms_volume_norm(x, self.db_level) + return x + + def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None: + """Save a waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. + """ + if self.do_rms_norm: + wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767 + else: + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + + wav_norm = wav_norm.astype(np.int16) + if pipe_out: + wav_buffer = BytesIO() + scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm) + wav_buffer.seek(0) + pipe_out.buffer.write(wav_buffer.read()) + scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm) + + def get_duration(self, filename: str) -> float: + """Get the duration of a wav file using Librosa. + + Args: + filename (str): Path to the wav file. + """ + return librosa.get_duration(filename=filename) diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py new file mode 100644 index 0000000..fd40ebb --- /dev/null +++ b/TTS/utils/audio/torch_transforms.py @@ -0,0 +1,165 @@ +import librosa +import torch +from torch import nn + + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing funtions using Torch for faster batch processing. + + Args: + + n_fft (int): + FFT window size for STFT. + + hop_length (int): + number of frames between STFT columns. + + win_length (int, optional): + STFT window length. + + pad_wav (bool, optional): + If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. + + window (str, optional): + The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" + + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + n_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + use_mel (bool, optional): + If True compute the melspectrograms otherwise. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. + + spec_gain (float, optional): + gain applied when converting amplitude to DB. Defaults to 1.0. + + power (float, optional): + Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. + + use_htk (bool, optional): + Use HTK formula in mel filter instead of Slaney. + + mel_norm (None, 'slaney', or number, optional): + If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). + + If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. + See `librosa.util.normalize` for a full description of supported norm values + (including `+-np.inf`). + + Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". + """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + power=None, + use_htk=False, + mel_norm="slaney", + normalized=False, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.power = power + self.use_htk = use_htk + self.mel_norm = mel_norm + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + self.normalized = normalized + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. + + Shapes: + x: [B x T] or [:math:`[B, 1, T]`] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=self.normalized, + onesided=True, + return_complex=False, + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + + if self.power is not None: + S = S**self.power + + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + sr=self.sample_rate, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain diff --git a/TTS/utils/callbacks.py b/TTS/utils/callbacks.py new file mode 100644 index 0000000..511d215 --- /dev/null +++ b/TTS/utils/callbacks.py @@ -0,0 +1,105 @@ +class TrainerCallback: + @staticmethod + def on_init_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_start"): + trainer.model.module.on_init_start(trainer) + else: + if hasattr(trainer.model, "on_init_start"): + trainer.model.on_init_start(trainer) + + if hasattr(trainer.criterion, "on_init_start"): + trainer.criterion.on_init_start(trainer) + + if hasattr(trainer.optimizer, "on_init_start"): + trainer.optimizer.on_init_start(trainer) + + @staticmethod + def on_init_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_end"): + trainer.model.module.on_init_end(trainer) + else: + if hasattr(trainer.model, "on_init_end"): + trainer.model.on_init_end(trainer) + + if hasattr(trainer.criterion, "on_init_end"): + trainer.criterion.on_init_end(trainer) + + if hasattr(trainer.optimizer, "on_init_end"): + trainer.optimizer.on_init_end(trainer) + + @staticmethod + def on_epoch_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_start"): + trainer.model.module.on_epoch_start(trainer) + else: + if hasattr(trainer.model, "on_epoch_start"): + trainer.model.on_epoch_start(trainer) + + if hasattr(trainer.criterion, "on_epoch_start"): + trainer.criterion.on_epoch_start(trainer) + + if hasattr(trainer.optimizer, "on_epoch_start"): + trainer.optimizer.on_epoch_start(trainer) + + @staticmethod + def on_epoch_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_end"): + trainer.model.module.on_epoch_end(trainer) + else: + if hasattr(trainer.model, "on_epoch_end"): + trainer.model.on_epoch_end(trainer) + + if hasattr(trainer.criterion, "on_epoch_end"): + trainer.criterion.on_epoch_end(trainer) + + if hasattr(trainer.optimizer, "on_epoch_end"): + trainer.optimizer.on_epoch_end(trainer) + + @staticmethod + def on_train_step_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_start"): + trainer.model.module.on_train_step_start(trainer) + else: + if hasattr(trainer.model, "on_train_step_start"): + trainer.model.on_train_step_start(trainer) + + if hasattr(trainer.criterion, "on_train_step_start"): + trainer.criterion.on_train_step_start(trainer) + + if hasattr(trainer.optimizer, "on_train_step_start"): + trainer.optimizer.on_train_step_start(trainer) + + @staticmethod + def on_train_step_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_end"): + trainer.model.module.on_train_step_end(trainer) + else: + if hasattr(trainer.model, "on_train_step_end"): + trainer.model.on_train_step_end(trainer) + + if hasattr(trainer.criterion, "on_train_step_end"): + trainer.criterion.on_train_step_end(trainer) + + if hasattr(trainer.optimizer, "on_train_step_end"): + trainer.optimizer.on_train_step_end(trainer) + + @staticmethod + def on_keyboard_interrupt(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_keyboard_interrupt"): + trainer.model.module.on_keyboard_interrupt(trainer) + else: + if hasattr(trainer.model, "on_keyboard_interrupt"): + trainer.model.on_keyboard_interrupt(trainer) + + if hasattr(trainer.criterion, "on_keyboard_interrupt"): + trainer.criterion.on_keyboard_interrupt(trainer) + + if hasattr(trainer.optimizer, "on_keyboard_interrupt"): + trainer.optimizer.on_keyboard_interrupt(trainer) diff --git a/TTS/utils/capacitron_optimizer.py b/TTS/utils/capacitron_optimizer.py new file mode 100644 index 0000000..7206ffd --- /dev/null +++ b/TTS/utils/capacitron_optimizer.py @@ -0,0 +1,67 @@ +from typing import Generator + +from trainer.trainer_utils import get_optimizer + + +class CapacitronOptimizer: + """Double optimizer class for the Capacitron model.""" + + def __init__(self, config: dict, model_params: Generator) -> None: + self.primary_params, self.secondary_params = self.split_model_parameters(model_params) + + optimizer_names = list(config.optimizer_params.keys()) + optimizer_parameters = list(config.optimizer_params.values()) + + self.primary_optimizer = get_optimizer( + optimizer_names[0], + optimizer_parameters[0], + config.lr, + parameters=self.primary_params, + ) + + self.secondary_optimizer = get_optimizer( + optimizer_names[1], + self.extract_optimizer_parameters(optimizer_parameters[1]), + optimizer_parameters[1]["lr"], + parameters=self.secondary_params, + ) + + self.param_groups = self.primary_optimizer.param_groups + + def first_step(self): + self.secondary_optimizer.step() + self.secondary_optimizer.zero_grad() + self.primary_optimizer.zero_grad() + + def step(self): + # Update param groups to display the correct learning rate + self.param_groups = self.primary_optimizer.param_groups + self.primary_optimizer.step() + + def zero_grad(self, set_to_none=False): + self.primary_optimizer.zero_grad(set_to_none) + self.secondary_optimizer.zero_grad(set_to_none) + + def load_state_dict(self, state_dict): + self.primary_optimizer.load_state_dict(state_dict[0]) + self.secondary_optimizer.load_state_dict(state_dict[1]) + + def state_dict(self): + return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()] + + @staticmethod + def split_model_parameters(model_params: Generator) -> list: + primary_params = [] + secondary_params = [] + for name, param in model_params: + if param.requires_grad: + if name == "capacitron_vae_layer.beta": + secondary_params.append(param) + else: + primary_params.append(param) + return [iter(primary_params), iter(secondary_params)] + + @staticmethod + def extract_optimizer_parameters(params: dict) -> dict: + """Extract parameters that are not the learning rate""" + return {k: v for k, v in params.items() if k != "lr"} diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py new file mode 100644 index 0000000..a51ef76 --- /dev/null +++ b/TTS/utils/distribute.py @@ -0,0 +1,20 @@ +# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py +import torch +import torch.distributed as dist + + +def reduce_tensor(tensor, num_gpus): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.reduce_op.SUM) + rt /= num_gpus + return rt + + +def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): + assert torch.cuda.is_available(), "Distributed mode requires CUDA." + + # Set cuda device so everything is done on the right GPU. + torch.cuda.set_device(rank % torch.cuda.device_count()) + + # Initialize distributed communication + dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) diff --git a/TTS/utils/download.py b/TTS/utils/download.py new file mode 100644 index 0000000..3f06b57 --- /dev/null +++ b/TTS/utils/download.py @@ -0,0 +1,206 @@ +# Adapted from https://github.com/pytorch/audio/ + +import hashlib +import logging +import os +import tarfile +import urllib +import urllib.request +import zipfile +from os.path import expanduser +from typing import Any, Iterable, List, Optional + +from torch.utils.model_zoo import tqdm + + +def stream_url( + url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True +) -> Iterable: + """Stream url by chunk + + Args: + url (str): Url. + start_byte (int or None, optional): Start streaming at that point (Default: ``None``). + block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + """ + + # If we already have the whole file, there is no need to download it again + req = urllib.request.Request(url, method="HEAD") + with urllib.request.urlopen(req) as response: + url_size = int(response.info().get("Content-Length", -1)) + if url_size == start_byte: + return + + req = urllib.request.Request(url) + if start_byte: + req.headers["Range"] = "bytes={}-".format(start_byte) + + with urllib.request.urlopen(req) as upointer, tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + total=url_size, + disable=not progress_bar, + ) as pbar: + num_bytes = 0 + while True: + chunk = upointer.read(block_size) + if not chunk: + break + yield chunk + num_bytes += len(chunk) + pbar.update(len(chunk)) + + +def download_url( + url: str, + download_folder: str, + filename: Optional[str] = None, + hash_value: Optional[str] = None, + hash_type: str = "sha256", + progress_bar: bool = True, + resume: bool = False, +) -> None: + """Download file to disk. + + Args: + url (str): Url. + download_folder (str): Folder to download file. + filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url + (Default: ``None``). + hash_value (str or None, optional): Hash for url (Default: ``None``). + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + resume (bool, optional): Enable resuming download (Default: ``False``). + """ + + req = urllib.request.Request(url, method="HEAD") + req_info = urllib.request.urlopen(req).info() # pylint: disable=consider-using-with + + # Detect filename + filename = filename or req_info.get_filename() or os.path.basename(url) + filepath = os.path.join(download_folder, filename) + if resume and os.path.exists(filepath): + mode = "ab" + local_size: Optional[int] = os.path.getsize(filepath) + + elif not resume and os.path.exists(filepath): + raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath)) + else: + mode = "wb" + local_size = None + + if hash_value and local_size == int(req_info.get("Content-Length", -1)): + with open(filepath, "rb") as file_obj: + if validate_file(file_obj, hash_value, hash_type): + return + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath)) + + with open(filepath, mode) as fpointer: + for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar): + fpointer.write(chunk) + + with open(filepath, "rb") as file_obj: + if hash_value and not validate_file(file_obj, hash_value, hash_type): + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath)) + + +def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool: + """Validate a given file object with its hash. + + Args: + file_obj: File object to read from. + hash_value (str): Hash for url. + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + + Returns: + bool: return True if its a valid file, else False. + """ + + if hash_type == "sha256": + hash_func = hashlib.sha256() + elif hash_type == "md5": + hash_func = hashlib.md5() + else: + raise ValueError + + while True: + # Read by chunk to avoid filling memory + chunk = file_obj.read(1024**2) + if not chunk: + break + hash_func.update(chunk) + + return hash_func.hexdigest() == hash_value + + +def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: + """Extract archive. + Args: + from_path (str): the path of the archive. + to_path (str or None, optional): the root path of the extraced files (directory of from_path) + (Default: ``None``) + overwrite (bool, optional): overwrite existing files (Default: ``False``) + + Returns: + list: List of paths to extracted files even if not overwritten. + """ + + if to_path is None: + to_path = os.path.dirname(from_path) + + try: + with tarfile.open(from_path, "r") as tar: + logging.info("Opened tar file %s.", from_path) + files = [] + for file_ in tar: # type: Any + file_path = os.path.join(to_path, file_.name) + if file_.isfile(): + files.append(file_path) + if os.path.exists(file_path): + logging.info("%s already extracted.", file_path) + if not overwrite: + continue + tar.extract(file_, to_path) + return files + except tarfile.ReadError: + pass + + try: + with zipfile.ZipFile(from_path, "r") as zfile: + logging.info("Opened zip file %s.", from_path) + files = zfile.namelist() + for file_ in files: + file_path = os.path.join(to_path, file_) + if os.path.exists(file_path): + logging.info("%s already extracted.", file_path) + if not overwrite: + continue + zfile.extract(file_, to_path) + return files + except zipfile.BadZipFile: + pass + + raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.") + + +def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str): + """Download dataset from kaggle. + Args: + dataset_path (str): + This the kaggle link to the dataset. for example vctk is 'mfekadu/english-multispeaker-corpus-for-voice-cloning' + dataset_name (str): Name of the folder the dataset will be saved in. + output_path (str): Path of the location you want the dataset folder to be saved to. + """ + data_path = os.path.join(output_path, dataset_name) + try: + import kaggle # pylint: disable=import-outside-toplevel + + kaggle.api.authenticate() + print(f"""\nDownloading {dataset_name}...""") + kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) + except OSError: + print( + f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}""" + ) diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py new file mode 100644 index 0000000..104dc7b --- /dev/null +++ b/TTS/utils/downloaders.py @@ -0,0 +1,126 @@ +import os +from typing import Optional + +from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive + + +def download_ljspeech(path: str): + """Download and extract LJSpeech dataset + + Args: + path (str): path to the directory where the dataset will be stored. + """ + os.makedirs(path, exist_ok=True) + url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_vctk(path: str, use_kaggle: Optional[bool] = False): + """Download and extract VCTK dataset. + + Args: + path (str): path to the directory where the dataset will be stored. + + use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False. + """ + if use_kaggle: + download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path) + else: + os.makedirs(path, exist_ok=True) + url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_tweb(path: str): + """Download and extract Tweb dataset + + Args: + path (str): Path to the directory where the dataset will be stored. + """ + download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path) + + +def download_libri_tts(path: str, subset: Optional[str] = "all"): + """Download and extract libri tts dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + subset (str, optional): Name of the subset to download. If you only want to download a certain + portion specify it here. Defaults to 'all'. + """ + + subset_dict = { + "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz", + "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz", + "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz", + "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz", + "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz", + "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz", + "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz", + } + + os.makedirs(path, exist_ok=True) + if subset == "all": + for sub, val in subset_dict.items(): + print(f" > Downloading {sub}...") + download_url(val, path) + basename = os.path.basename(val) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + print(" > All subsets downloaded") + else: + url = subset_dict[subset] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_thorsten_de(path: str): + """Download and extract Thorsten german male voice dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + """ + os.makedirs(path, exist_ok=True) + url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_mailabs(path: str, language: str = "english"): + """Download and extract Mailabs dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + language (str): Language subset to download. Defaults to english. + """ + language_dict = { + "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz", + "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz", + "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz", + "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz", + "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz", + } + os.makedirs(path, exist_ok=True) + url = language_dict[language] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py new file mode 100644 index 0000000..4fa4741 --- /dev/null +++ b/TTS/utils/generic_utils.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- +import datetime +import importlib +import logging +import os +import re +import subprocess +import sys +from pathlib import Path +from typing import Dict + +import fsspec +import torch + + +def to_cuda(x: torch.Tensor) -> torch.Tensor: + if x is None: + return None + if torch.is_tensor(x): + x = x.contiguous() + if torch.cuda.is_available(): + x = x.cuda(non_blocking=True) + return x + + +def get_cuda(): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return use_cuda, device + + +def get_git_branch(): + try: + out = subprocess.check_output(["git", "branch"]).decode("utf8") + current = next(line for line in out.split("\n") if line.startswith("*")) + current.replace("* ", "") + except subprocess.CalledProcessError: + current = "inside_docker" + except (FileNotFoundError, StopIteration) as e: + current = "unknown" + return current + + +def get_commit_hash(): + """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" + # try: + # subprocess.check_output(['git', 'diff-index', '--quiet', + # 'HEAD']) # Verify client is clean + # except: + # raise RuntimeError( + # " !! Commit before training to get the commit hash.") + try: + commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() + # Not copying .git folder into docker container + except (subprocess.CalledProcessError, FileNotFoundError): + commit = "0000000" + return commit + + +def get_experiment_folder_path(root_path, model_name): + """Get an experiment folder path with the current date and time""" + date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") + commit_hash = get_commit_hash() + output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) + return output_folder + + +def remove_experiment_folder(experiment_path): + """Check folder if there is a checkpoint, otherwise remove the folder""" + fs = fsspec.get_mapper(experiment_path).fs + checkpoint_files = fs.glob(experiment_path + "/*.pth") + if not checkpoint_files: + if fs.exists(experiment_path): + fs.rm(experiment_path, recursive=True) + print(" ! Run is removed from {}".format(experiment_path)) + else: + print(" ! Run is kept in {}".format(experiment_path)) + + +def count_parameters(model): + r"""Count number of trainable parameters in a network""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def to_camel(text): + text = text.capitalize() + text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + text = text.replace("Tts", "TTS") + text = text.replace("vc", "VC") + return text + + +def find_module(module_path: str, module_name: str) -> object: + module_name = module_name.lower() + module = importlib.import_module(module_path + "." + module_name) + class_name = to_camel(module_name) + return getattr(module, class_name) + + +def import_class(module_path: str) -> object: + """Import a class from a module path. + + Args: + module_path (str): The module path of the class. + + Returns: + object: The imported class. + """ + class_name = module_path.split(".")[-1] + module_path = ".".join(module_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_import_path(obj: object) -> str: + """Get the import path of a class. + + Args: + obj (object): The class object. + + Returns: + str: The import path of the class. + """ + return ".".join([type(obj).__module__, type(obj).__name__]) + + +def get_user_data_dir(appname): + TTS_HOME = os.environ.get("TTS_HOME") + XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME") + if TTS_HOME is not None: + ans = Path(TTS_HOME).expanduser().resolve(strict=False) + elif XDG_DATA_HOME is not None: + ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel + + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" + ) + dir_, _ = winreg.QueryValueEx(key, "Local AppData") + ans = Path(dir_).resolve(strict=False) + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() + else: + ans = Path.home().joinpath(".local/share") + return ans.joinpath(appname) + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + print(" | > Layer missing in the model definition: {}".format(k)) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + return model_dict + + +def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: + """Format kwargs to hande auxilary inputs to models. + + Args: + def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`. + kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model. + + Returns: + Dict: arguments with formatted auxilary inputs. + """ + kwargs = kwargs.copy() + for name in def_args: + if name not in kwargs or kwargs[name] is None: + kwargs[name] = def_args[name] + return kwargs + + +class KeepAverage: + def __init__(self): + self.avg_values = {} + self.iters = {} + + def __getitem__(self, key): + return self.avg_values[key] + + def items(self): + return self.avg_values.items() + + def add_value(self, name, init_val=0, init_iter=0): + self.avg_values[name] = init_val + self.iters[name] = init_iter + + def update_value(self, name, value, weighted_avg=False): + if name not in self.avg_values: + # add value if not exist before + self.add_value(name, init_val=value) + else: + # else update existing value + if weighted_avg: + self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value + self.iters[name] += 1 + else: + self.avg_values[name] = self.avg_values[name] * self.iters[name] + value + self.iters[name] += 1 + self.avg_values[name] /= self.iters[name] + + def add_values(self, name_dict): + for key, value in name_dict.items(): + self.add_value(key, init_val=value) + + def update_values(self, value_dict): + for key, value in value_dict.items(): + self.update_value(key, value) + + +def get_timestamp(): + return datetime.now().strftime("%y%m%d-%H%M%S") + + +def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): + lg = logging.getLogger(logger_name) + formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") + lg.setLevel(level) + if tofile: + log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp())) + fh = logging.FileHandler(log_file, mode="w") + fh.setFormatter(formatter) + lg.addHandler(fh) + if screen: + sh = logging.StreamHandler() + sh.setFormatter(formatter) + lg.addHandler(sh) diff --git a/TTS/utils/io.py b/TTS/utils/io.py new file mode 100644 index 0000000..3107ba6 --- /dev/null +++ b/TTS/utils/io.py @@ -0,0 +1,70 @@ +import os +import pickle as pickle_tts +from typing import Any, Callable, Dict, Union + +import fsspec +import torch + +from TTS.utils.generic_utils import get_user_data_dir + + +class RenamingUnpickler(pickle_tts.Unpickler): + """Overload default pickler to solve module renaming problem""" + + def find_class(self, module, name): + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) + + +class AttrDict(dict): + """A custom dict which converts dict keys + to class attributes""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def load_fsspec( + path: str, + map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + cache: bool = True, + **kwargs, +) -> Any: + """Like torch.load but can load from other locations (e.g. s3:// , gs://). + + Args: + path: Any path or url supported by fsspec. + map_location: torch.device or str. + cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True. + **kwargs: Keyword arguments forwarded to torch.load. + + Returns: + Object stored in path. + """ + is_local = os.path.isdir(path) or os.path.isfile(path) + if cache and not is_local: + with fsspec.open( + f"filecache::{path}", + filecache={"cache_storage": str(get_user_data_dir("tts_cache"))}, + mode="rb", + ) as f: + return torch.load(f, map_location=map_location, **kwargs) + else: + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location=map_location, **kwargs) + + +def load_checkpoint( + model, checkpoint_path, use_cuda=False, eval=False, cache=False +): # pylint: disable=redefined-builtin + try: + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache) + model.load_state_dict(state["model"]) + if use_cuda: + model.cuda() + if eval: + model.eval() + return model, state diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py new file mode 100644 index 0000000..3a527f4 --- /dev/null +++ b/TTS/utils/manage.py @@ -0,0 +1,621 @@ +import json +import os +import re +import tarfile +import zipfile +from pathlib import Path +from shutil import copyfile, rmtree +from typing import Dict, List, Tuple + +import fsspec +import requests +from tqdm import tqdm + +from TTS.config import load_config, read_json_with_comments +from TTS.utils.generic_utils import get_user_data_dir + +LICENSE_URLS = { + "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", + "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/", + "mit": "https://choosealicense.com/licenses/mit/", + "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/", + "apache2": "https://choosealicense.com/licenses/apache-2.0/", + "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/", + "cpml": "https://coqui.ai/cpml.txt", +} + + +class ModelManager(object): + tqdm_progress = None + """Manage TTS models defined in .models.json. + It provides an interface to list and download + models defines in '.model.json' + + Models are downloaded under '.TTS' folder in the user's + home path. + + Args: + models_file (str): path to .model.json file. Defaults to None. + output_prefix (str): prefix to `tts` to download models. Defaults to None + progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. + verbose (bool): print info. Defaults to True. + """ + + def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True): + super().__init__() + self.progress_bar = progress_bar + self.verbose = verbose + if output_prefix is None: + self.output_prefix = get_user_data_dir("tts") + else: + self.output_prefix = os.path.join(output_prefix, "tts") + self.models_dict = None + if models_file is not None: + self.read_models_file(models_file) + else: + # try the default location + path = Path(__file__).parent / "../.models.json" + self.read_models_file(path) + + def read_models_file(self, file_path): + """Read .models.json as a dict + + Args: + file_path (str): path to .models.json. + """ + self.models_dict = read_json_with_comments(file_path) + + def _list_models(self, model_type, model_count=0): + if self.verbose: + print("\n Name format: type/language/dataset/model") + model_list = [] + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + output_path = os.path.join(self.output_prefix, model_full_name) + if self.verbose: + if os.path.exists(output_path): + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") + else: + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + model_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + return model_list + + def _list_for_model_type(self, model_type): + models_name_list = [] + model_count = 1 + models_name_list.extend(self._list_models(model_type, model_count)) + return models_name_list + + def list_models(self): + models_name_list = [] + model_count = 1 + for model_type in self.models_dict: + model_list = self._list_models(model_type, model_count) + models_name_list.extend(model_list) + return models_name_list + + def model_info_by_idx(self, model_query): + """Print the description of the model from .models.json file using model_idx + + Args: + model_query (str): / + """ + model_name_list = [] + model_type, model_query_idx = model_query.split("/") + try: + model_query_idx = int(model_query_idx) + if model_query_idx <= 0: + print("> model_query_idx should be a positive integer!") + return + except: + print("> model_query_idx should be an integer!") + return + model_count = 0 + if model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + else: + print(f"> model_type {model_type} does not exist in the list.") + return + if model_query_idx > model_count: + print(f"model query idx exceeds the number of available models [{model_count}] ") + else: + model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") + print(f"> model type : {model_type}") + print(f"> language supported : {lang}") + print(f"> dataset used : {dataset}") + print(f"> model name : {model}") + if "description" in self.models_dict[model_type][lang][dataset][model]: + print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}") + else: + print("> description : coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}") + + def model_info_by_full_name(self, model_query_name): + """Print the description of the model from .models.json file using model_full_name + + Args: + model_query_name (str): Format is /// + """ + model_type, lang, dataset, model = model_query_name.split("/") + if model_type in self.models_dict: + if lang in self.models_dict[model_type]: + if dataset in self.models_dict[model_type][lang]: + if model in self.models_dict[model_type][lang][dataset]: + print(f"> model type : {model_type}") + print(f"> language supported : {lang}") + print(f"> dataset used : {dataset}") + print(f"> model name : {model}") + if "description" in self.models_dict[model_type][lang][dataset][model]: + print( + f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}" + ) + else: + print("> description : coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + print( + f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}" + ) + else: + print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.") + else: + print(f"> dataset {dataset} does not exist for {model_type}/{lang}.") + else: + print(f"> lang {lang} does not exist for {model_type}.") + else: + print(f"> model_type {model_type} does not exist in the list.") + + def list_tts_models(self): + """Print all `TTS` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("tts_models") + + def list_vocoder_models(self): + """Print all the `vocoder` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("vocoder_models") + + def list_vc_models(self): + """Print all the voice conversion models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("voice_conversion_models") + + def list_langs(self): + """Print all the available languages""" + print(" Name format: type/language") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + print(f" >: {model_type}/{lang} ") + + def list_datasets(self): + """Print all the datasets""" + print(" Name format: type/language/dataset") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + print(f" >: {model_type}/{lang}/{dataset}") + + @staticmethod + def print_model_license(model_item: Dict): + """Print the license of a model + + Args: + model_item (dict): model item in the models.json + """ + if "license" in model_item and model_item["license"].strip() != "": + print(f" > Model's license - {model_item['license']}") + if model_item["license"].lower() in LICENSE_URLS: + print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") + else: + print(" > Check https://opensource.org/licenses for more info.") + else: + print(" > Model's license - No license information available") + + def _download_github_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["github_rls_url"], list): + self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + + def _download_hf_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["hf_url"], list): + self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar) + + def download_fairseq_model(self, model_name, output_path): + URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/" + _, lang, _, _ = model_name.split("/") + model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") + self._download_tar_file(model_download_uri, output_path, self.progress_bar) + + @staticmethod + def set_model_url(model_item: Dict): + model_item["model_url"] = None + if "github_rls_url" in model_item: + model_item["model_url"] = model_item["github_rls_url"] + elif "hf_url" in model_item: + model_item["model_url"] = model_item["hf_url"] + elif "fairseq" in model_item["model_name"]: + model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/" + elif "xtts" in model_item["model_name"]: + model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/" + return model_item + + def _set_model_item(self, model_name): + # fetch model info from the dict + if "fairseq" in model_name: + model_type = "tts_models" + lang = model_name.split("/")[1] + model_item = { + "model_type": "tts_models", + "license": "CC BY-NC 4.0", + "default_vocoder": None, + "author": "fairseq", + "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", + } + model_item["model_name"] = model_name + elif "xtts" in model_name and len(model_name.split("/")) != 4: + # loading xtts models with only model name (e.g. xtts_v2.0.2) + # check model name has the version number with regex + version_regex = r"v\d+\.\d+\.\d+" + if re.search(version_regex, model_name): + model_version = model_name.split("_")[-1] + else: + model_version = "main" + model_type = "tts_models" + lang = "multilingual" + dataset = "multi-dataset" + model = model_name + model_item = { + "default_vocoder": None, + "license": "CPML", + "contact": "info@coqui.ai", + "tos_required": True, + "hf_url": [ + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth", + ], + } + else: + # get model from models.json + model_type, lang, dataset, model = model_name.split("/") + model_item = self.models_dict[model_type][lang][dataset][model] + model_item["model_type"] = model_type + + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + md5hash = model_item["model_hash"] if "model_hash" in model_item else None + model_item = self.set_model_url(model_item) + return model_item, model_full_name, model, md5hash + + @staticmethod + def ask_tos(model_full_path): + """Ask the user to agree to the terms of service""" + tos_path = os.path.join(model_full_path, "tos_agreed.txt") + print(" > You must confirm the following:") + print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"') + print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]') + answer = input(" | | > ") + if answer.lower() == "y": + with open(tos_path, "w", encoding="utf-8") as f: + f.write("I have read, understood and agreed to the Terms and Conditions.") + return True + return False + + @staticmethod + def tos_agreed(model_item, model_full_path): + """Check if the user has agreed to the terms of service""" + if "tos_required" in model_item and model_item["tos_required"]: + tos_path = os.path.join(model_full_path, "tos_agreed.txt") + if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1": + return True + return False + return True + + def create_dir_and_download_model(self, model_name, model_item, output_path): + os.makedirs(output_path, exist_ok=True) + # handle TOS + if not self.tos_agreed(model_item, output_path): + if not self.ask_tos(output_path): + os.rmdir(output_path) + raise Exception(" [!] You must agree to the terms of service to use this model.") + print(f" > Downloading model to {output_path}") + try: + if "fairseq" in model_name: + self.download_fairseq_model(model_name, output_path) + elif "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + + except requests.RequestException as e: + print(f" > Failed to download the model file to {output_path}") + rmtree(output_path) + raise e + self.print_model_license(model_item=model_item) + + def check_if_configs_are_equal(self, model_name, model_item, output_path): + with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: + config_local = json.load(f) + remote_url = None + for url in model_item["hf_url"]: + if "config.json" in url: + remote_url = url + break + + with fsspec.open(remote_url, "r", encoding="utf-8") as f: + config_remote = json.load(f) + + if not config_local == config_remote: + print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") + self.create_dir_and_download_model(model_name, model_item, output_path) + + def download_model(self, model_name): + """Download model files given the full model name. + Model name is in the format + 'type/language/dataset/model' + e.g. 'tts_model/en/ljspeech/tacotron' + + Every model must have the following files: + - *.pth : pytorch model checkpoint file. + - config.json : model config file. + - scale_stats.npy (if exist): scale values for preprocessing. + + Args: + model_name (str): model name as explained above. + """ + model_item, model_full_name, model, md5sum = self._set_model_item(model_name) + # set the model specific output path + output_path = os.path.join(self.output_prefix, model_full_name) + if os.path.exists(output_path): + if md5sum is not None: + md5sum_file = os.path.join(output_path, "hash.md5") + if os.path.isfile(md5sum_file): + with open(md5sum_file, mode="r") as f: + if not f.read() == md5sum: + print(f" > {model_name} has been updated, clearing model cache...") + self.create_dir_and_download_model(model_name, model_item, output_path) + else: + print(f" > {model_name} is already downloaded.") + else: + print(f" > {model_name} has been updated, clearing model cache...") + self.create_dir_and_download_model(model_name, model_item, output_path) + # if the configs are different, redownload it + # ToDo: we need a better way to handle it + if "xtts" in model_name: + try: + self.check_if_configs_are_equal(model_name, model_item, output_path) + except: + pass + else: + print(f" > {model_name} is already downloaded.") + else: + self.create_dir_and_download_model(model_name, model_item, output_path) + + # find downloaded files + output_model_path = output_path + output_config_path = None + if ( + model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name + ): # TODO:This is stupid but don't care for now. + output_model_path, output_config_path = self._find_files(output_path) + # update paths in the config.json + self._update_paths(output_path, output_config_path) + return output_model_path, output_config_path, model_item + + @staticmethod + def _find_files(output_path: str) -> Tuple[str, str]: + """Find the model and config files in the output path + + Args: + output_path (str): path to the model files + + Returns: + Tuple[str, str]: path to the model file and config file + """ + model_file = None + config_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + model_file = os.path.join(output_path, file_name) + elif file_name == "config.json": + config_file = os.path.join(output_path, file_name) + if model_file is None: + raise ValueError(" [!] Model file not found in the output path") + if config_file is None: + raise ValueError(" [!] Config file not found in the output path") + return model_file, config_file + + @staticmethod + def _find_speaker_encoder(output_path: str) -> str: + """Find the speaker encoder file in the output path + + Args: + output_path (str): path to the model files + + Returns: + str: path to the speaker encoder file + """ + speaker_encoder_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_se.pth", "model_se.pth.tar"]: + speaker_encoder_file = os.path.join(output_path, file_name) + return speaker_encoder_file + + def _update_paths(self, output_path: str, config_path: str) -> None: + """Update paths for certain files in config.json after download. + + Args: + output_path (str): local path the model is downloaded to. + config_path (str): local config.json path. + """ + output_stats_path = os.path.join(output_path, "scale_stats.npy") + output_d_vector_file_path = os.path.join(output_path, "speakers.json") + output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth") + output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth") + speaker_encoder_config_path = os.path.join(output_path, "config_se.json") + speaker_encoder_model_path = self._find_speaker_encoder(output_path) + + # update the scale_path.npy file path in the model config.json + self._update_path("audio.stats_path", output_stats_path, config_path) + + # update the speakers.json file path in the model config.json to the current path + self._update_path("d_vector_file", output_d_vector_file_path, config_path) + self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path) + + # update the speaker_ids.json file path in the model config.json to the current path + self._update_path("speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path) + + # update the speaker_encoder file path in the model config.json to the current path + self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) + self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) + + @staticmethod + def _update_path(field_name, new_path, config_path): + """Update the path in the model config.json for the current environment after download""" + if new_path and os.path.exists(new_path): + config = load_config(config_path) + field_names = field_name.split(".") + if len(field_names) > 1: + # field name points to a sub-level field + sub_conf = config + for fd in field_names[:-1]: + if fd in sub_conf: + sub_conf = sub_conf[fd] + else: + return + if isinstance(sub_conf[field_names[-1]], list): + sub_conf[field_names[-1]] = [new_path] + else: + sub_conf[field_names[-1]] = new_path + else: + # field name points to a top-level field + if not field_name in config: + return + if isinstance(config[field_name], list): + config[field_name] = [new_path] + else: + config[field_name] = new_path + config.save_json(config_path) + + @staticmethod + def _download_zip_file(file_url, output_folder, progress_bar): + """Download the github releases""" + # download the file + r = requests.get(file_url, stream=True) + # extract the file + try: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_zip_name, "wb") as file: + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + with zipfile.ZipFile(temp_zip_name) as z: + z.extractall(output_folder) + os.remove(temp_zip_name) # delete zip after extract + except zipfile.BadZipFile: + print(f" > Error: Bad zip file - {file_url}") + raise zipfile.BadZipFile # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in z.namelist(): + src_path = os.path.join(output_folder, file_path) + if os.path.isfile(src_path): + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove redundant (hidden or not) folders + for file_path in z.namelist(): + if os.path.isdir(os.path.join(output_folder, file_path)): + rmtree(os.path.join(output_folder, file_path)) + + @staticmethod + def _download_tar_file(file_url, output_folder, progress_bar): + """Download the github releases""" + # download the file + r = requests.get(file_url, stream=True) + # extract the file + try: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_tar_name, "wb") as file: + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + with tarfile.open(temp_tar_name) as t: + t.extractall(output_folder) + tar_names = t.getnames() + os.remove(temp_tar_name) # delete tar after extract + except tarfile.ReadError: + print(f" > Error: Bad tar file - {file_url}") + raise tarfile.ReadError # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): + src_path = os.path.join(output_folder, tar_names[0], file_path) + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove the extracted folder + rmtree(os.path.join(output_folder, tar_names[0])) + + @staticmethod + def _download_model_files(file_urls, output_folder, progress_bar): + """Download the github releases""" + for file_url in file_urls: + # download the file + r = requests.get(file_url, stream=True) + # extract the file + bease_filename = file_url.split("/")[-1] + temp_zip_name = os.path.join(output_folder, bease_filename) + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + with open(temp_zip_name, "wb") as file: + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + + @staticmethod + def _check_dict_key(my_dict, key): + if key in my_dict.keys() and my_dict[key] is not None: + if not isinstance(key, str): + return True + if isinstance(key, str) and len(my_dict[key]) > 0: + return True + return False diff --git a/TTS/utils/radam.py b/TTS/utils/radam.py new file mode 100644 index 0000000..cbd1499 --- /dev/null +++ b/TTS/utils/radam.py @@ -0,0 +1,105 @@ +# modified from https://github.com/LiyuanLucasLiu/RAdam + +import math + +import torch +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if eps < 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)] + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # pylint: disable=useless-super-delegation + super().__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + + return loss diff --git a/TTS/utils/samplers.py b/TTS/utils/samplers.py new file mode 100644 index 0000000..b08a763 --- /dev/null +++ b/TTS/utils/samplers.py @@ -0,0 +1,201 @@ +import math +import random +from typing import Callable, List, Union + +from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler + + +class SubsetSampler(Sampler): + """ + Samples elements sequentially from a given list of indices. + + Args: + indices (list): a sequence of indices + """ + + def __init__(self, indices): + super().__init__(indices) + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in range(len(self.indices))) + + def __len__(self): + return len(self.indices) + + +class PerfectBatchSampler(Sampler): + """ + Samples a mini-batch of indices for a balanced class batching + + Args: + dataset_items(list): dataset items to sample from. + classes (list): list of classes of dataset_items to sample from. + batch_size (int): total number of samples to be sampled in a mini-batch. + num_gpus (int): number of GPU in the data parallel mode. + shuffle (bool): if True, samples randomly, otherwise samples sequentially. + drop_last (bool): if True, drops last incomplete batch. + """ + + def __init__( + self, + dataset_items, + classes, + batch_size, + num_classes_in_batch, + num_gpus=1, + shuffle=True, + drop_last=False, + label_key="class_name", + ): + super().__init__(dataset_items) + assert ( + batch_size % (num_classes_in_batch * num_gpus) == 0 + ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)." + + label_indices = {} + for idx, item in enumerate(dataset_items): + label = item[label_key] + if label not in label_indices.keys(): + label_indices[label] = [idx] + else: + label_indices[label].append(idx) + + if shuffle: + self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes] + else: + self._samplers = [SubsetSampler(label_indices[key]) for key in classes] + + self._batch_size = batch_size + self._drop_last = drop_last + self._dp_devices = num_gpus + self._num_classes_in_batch = num_classes_in_batch + + def __iter__(self): + batch = [] + if self._num_classes_in_batch != len(self._samplers): + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + else: + valid_samplers_idx = None + + iters = [iter(s) for s in self._samplers] + done = False + + while True: + b = [] + for i, it in enumerate(iters): + if valid_samplers_idx is not None and i not in valid_samplers_idx: + continue + idx = next(it, None) + if idx is None: + done = True + break + b.append(idx) + if done: + break + batch += b + if len(batch) == self._batch_size: + yield batch + batch = [] + if valid_samplers_idx is not None: + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + + if not self._drop_last: + if len(batch) > 0: + groups = len(batch) // self._num_classes_in_batch + if groups % self._dp_devices == 0: + yield batch + else: + batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch] + if len(batch) > 0: + yield batch + + def __len__(self): + class_batch_size = self._batch_size // self._num_classes_in_batch + return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers) + + +def identity(x): + return x + + +class SortedSampler(Sampler): + """Samples elements sequentially, always in the same order. + + Taken from https://github.com/PetrochukM/PyTorch-NLP + + Args: + data (iterable): Iterable data. + sort_key (callable): Specifies a function of one argument that is used to extract a + numerical comparison key from each list element. + + Example: + >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + """ + + def __init__(self, data, sort_key: Callable = identity): + super().__init__(data) + self.data = data + self.sort_key = sort_key + zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] + zip_ = sorted(zip_, key=lambda r: r[1]) + self.sorted_indexes = [item[0] for item in zip_] + + def __iter__(self): + return iter(self.sorted_indexes) + + def __len__(self): + return len(self.data) + + +class BucketBatchSampler(BatchSampler): + """Bucket batch sampler + + Adapted from https://github.com/PetrochukM/PyTorch-NLP + + Args: + sampler (torch.data.utils.sampler.Sampler): + batch_size (int): Size of mini-batch. + drop_last (bool): If `True` the sampler will drop the last batch if its size would be less + than `batch_size`. + data (list): List of data samples. + sort_key (callable, optional): Callable to specify a comparison key for sorting. + bucket_size_multiplier (int, optional): Buckets are of size + `batch_size * bucket_size_multiplier`. + + Example: + >>> sampler = WeightedRandomSampler(weights, len(weights)) + >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True) + """ + + def __init__( + self, + sampler, + data, + batch_size, + drop_last, + sort_key: Union[Callable, List] = identity, + bucket_size_multiplier=100, + ): + super().__init__(sampler, batch_size, drop_last) + self.data = data + self.sort_key = sort_key + _bucket_size = batch_size * bucket_size_multiplier + if hasattr(sampler, "__len__"): + _bucket_size = min(_bucket_size, len(sampler)) + self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) + + def __iter__(self): + for idxs in self.bucket_sampler: + bucket_data = [self.data[idx] for idx in idxs] + sorted_sampler = SortedSampler(bucket_data, self.sort_key) + for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))): + sorted_idxs = [idxs[i] for i in batch_idx] + yield sorted_idxs + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + return math.ceil(len(self.sampler) / self.batch_size) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py new file mode 100644 index 0000000..b98647c --- /dev/null +++ b/TTS/utils/synthesizer.py @@ -0,0 +1,505 @@ +import os +import time +from typing import List + +import numpy as np +import pysbd +import torch +from torch import nn + +from TTS.config import load_config +from TTS.tts.configs.vits_config import VitsConfig +from TTS.tts.models import setup_model as setup_tts_model +from TTS.tts.models.vits import Vits + +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import +from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import save_wav +from TTS.vc.models import setup_model as setup_vc_model +from TTS.vocoder.models import setup_model as setup_vocoder_model +from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input + + +class Synthesizer(nn.Module): + def __init__( + self, + tts_checkpoint: str = "", + tts_config_path: str = "", + tts_speakers_file: str = "", + tts_languages_file: str = "", + vocoder_checkpoint: str = "", + vocoder_config: str = "", + encoder_checkpoint: str = "", + encoder_config: str = "", + vc_checkpoint: str = "", + vc_config: str = "", + model_dir: str = "", + voice_dir: str = None, + use_cuda: bool = False, + ) -> None: + """General 🐸 TTS interface for inference. It takes a tts and a vocoder + model and synthesize speech from the provided text. + + The text is divided into a list of sentences using `pysbd` and synthesize + speech on each sentence separately. + + If you have certain special characters in your text, you need to handle + them before providing the text to Synthesizer. + + TODO: set the segmenter based on the source language + + Args: + tts_checkpoint (str, optional): path to the tts model file. + tts_config_path (str, optional): path to the tts config file. + vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None. + vocoder_config (str, optional): path to the vocoder config file. Defaults to None. + encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`, + encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`, + vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`, + vc_config (str, optional): path to the voice conversion config file. Defaults to `""`, + use_cuda (bool, optional): enable/disable cuda. Defaults to False. + """ + super().__init__() + self.tts_checkpoint = tts_checkpoint + self.tts_config_path = tts_config_path + self.tts_speakers_file = tts_speakers_file + self.tts_languages_file = tts_languages_file + self.vocoder_checkpoint = vocoder_checkpoint + self.vocoder_config = vocoder_config + self.encoder_checkpoint = encoder_checkpoint + self.encoder_config = encoder_config + self.vc_checkpoint = vc_checkpoint + self.vc_config = vc_config + self.use_cuda = use_cuda + + self.tts_model = None + self.vocoder_model = None + self.vc_model = None + self.speaker_manager = None + self.tts_speakers = {} + self.language_manager = None + self.num_languages = 0 + self.tts_languages = {} + self.d_vector_dim = 0 + self.seg = self._get_segmenter("en") + self.use_cuda = use_cuda + self.voice_dir = voice_dir + if self.use_cuda: + assert torch.cuda.is_available(), "CUDA is not availabe on this machine." + + if tts_checkpoint: + self._load_tts(tts_checkpoint, tts_config_path, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + + if vocoder_checkpoint: + self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] + + if vc_checkpoint: + self._load_vc(vc_checkpoint, vc_config, use_cuda) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] + + if model_dir: + if "fairseq" in model_dir: + self._load_fairseq_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + else: + self._load_tts_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] + + @staticmethod + def _get_segmenter(lang: str): + """get the sentence segmenter for the given language. + + Args: + lang (str): target language code. + + Returns: + [type]: [description] + """ + return pysbd.Segmenter(language=lang, clean=True) + + def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None: + """Load the voice conversion model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + vc_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + # pylint: disable=global-statement + self.vc_config = load_config(vc_config_path) + self.vc_model = setup_vc_model(config=self.vc_config) + self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) + if use_cuda: + self.vc_model.cuda() + + def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the fairseq model from a directory. + + We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + self.tts_config = VitsConfig() + self.tts_model = Vits.init_from_config(self.tts_config) + self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) + self.tts_config = self.tts_model.config + if use_cuda: + self.tts_model.cuda() + + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the TTS model from a directory. + + We assume the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + config = load_config(os.path.join(model_dir, "config.json")) + self.tts_config = config + self.tts_model = setup_tts_model(config) + self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) + if use_cuda: + self.tts_model.cuda() + + def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: + """Load the TTS model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + 5. Init the speaker manager in the model. + + Args: + tts_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + # pylint: disable=global-statement + self.tts_config = load_config(tts_config_path) + if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: + raise ValueError("Phonemizer is not defined in the TTS config.") + + self.tts_model = setup_tts_model(config=self.tts_config) + + if not self.encoder_checkpoint: + self._set_speaker_encoder_paths_from_tts_config() + + self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) + if use_cuda: + self.tts_model.cuda() + + if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"): + self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda) + + def _set_speaker_encoder_paths_from_tts_config(self): + """Set the encoder paths from the tts model config for models with speaker encoders.""" + if hasattr(self.tts_config, "model_args") and hasattr( + self.tts_config.model_args, "speaker_encoder_config_path" + ): + self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path + self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path + + def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: + """Load the vocoder model. + + 1. Load the vocoder config. + 2. Init the AudioProcessor for the vocoder. + 3. Init the vocoder model from the config. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + model_file (str): path to the model checkpoint. + model_config (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + self.vocoder_config = load_config(model_config) + self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) + self.vocoder_model = setup_vocoder_model(self.vocoder_config) + self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) + if use_cuda: + self.vocoder_model.cuda() + + def split_into_sentences(self, text) -> List[str]: + """Split give text into sentences. + + Args: + text (str): input text in string format. + + Returns: + List[str]: list of sentences. + """ + return self.seg.segment(text) + + def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None: + """Save the waveform as a file. + + Args: + wav (List[int]): waveform as a list of values. + path (str): output path to save the waveform. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. + """ + # if tensor convert to numpy + if torch.is_tensor(wav): + wav = wav.cpu().numpy() + if isinstance(wav, list): + wav = np.array(wav) + save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out) + + def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]: + output_wav = self.vc_model.voice_conversion(source_wav, target_wav) + return output_wav + + def tts( + self, + text: str = "", + speaker_name: str = "", + language_name: str = "", + speaker_wav=None, + style_wav=None, + style_text=None, + reference_wav=None, + reference_speaker_name=None, + split_sentences: bool = True, + **kwargs, + ) -> List[int]: + """🐸 TTS magic. Run all the models and generate speech. + + Args: + text (str): input text. + speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "". + language_name (str, optional): language id for multi-language models. Defaults to "". + speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None. + style_wav ([type], optional): style waveform for GST. Defaults to None. + style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None. + reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None. + reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None. + split_sentences (bool, optional): split the input text into sentences. Defaults to True. + **kwargs: additional arguments to pass to the TTS model. + Returns: + List[int]: [description] + """ + start_time = time.time() + wavs = [] + + if not text and not reference_wav: + raise ValueError( + "You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API." + ) + + if text: + sens = [text] + if split_sentences: + print(" > Text splitted to sentences.") + sens = self.split_into_sentences(text) + print(sens) + + # handle multi-speaker + if "voice_dir" in kwargs: + self.voice_dir = kwargs["voice_dir"] + kwargs.pop("voice_dir") + speaker_embedding = None + speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): + if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts": + if self.tts_config.use_d_vector_file: + # get the average speaker embedding from the saved d_vectors. + speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding( + speaker_name, num_samples=None, randomize=False + ) + speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name] + # handle Neon models with single speaker. + elif len(self.tts_model.speaker_manager.name_to_id) == 1: + speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0] + elif not speaker_name and not speaker_wav: + raise ValueError( + " [!] Looks like you are using a multi-speaker model. " + "You need to define either a `speaker_idx` or a `speaker_wav` to use a multi-speaker model." + ) + else: + speaker_embedding = None + else: + if speaker_name and self.voice_dir is None: + raise ValueError( + f" [!] Missing speakers.json file path for selecting speaker {speaker_name}." + "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. " + ) + + # handle multi-lingual + language_id = None + if self.tts_languages_file or ( + hasattr(self.tts_model, "language_manager") + and self.tts_model.language_manager is not None + and not self.tts_config.model == "xtts" + ): + if len(self.tts_model.language_manager.name_to_id) == 1: + language_id = list(self.tts_model.language_manager.name_to_id.values())[0] + + elif language_name and isinstance(language_name, str): + try: + language_id = self.tts_model.language_manager.name_to_id[language_name] + except KeyError as e: + raise ValueError( + f" [!] Looks like you use a multi-lingual model. " + f"Language {language_name} is not in the available languages: " + f"{self.tts_model.language_manager.name_to_id.keys()}." + ) from e + + elif not language_name: + raise ValueError( + " [!] Look like you use a multi-lingual model. " + "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model." + ) + + else: + raise ValueError( + f" [!] Missing language_ids.json file path for selecting language {language_name}." + "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. " + ) + + # compute a new d_vector from the given clip. + if ( + speaker_wav is not None + and self.tts_model.speaker_manager is not None + and hasattr(self.tts_model.speaker_manager, "encoder_ap") + and self.tts_model.speaker_manager.encoder_ap is not None + ): + speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) + + vocoder_device = "cpu" + use_gl = self.vocoder_model is None + if not use_gl: + vocoder_device = next(self.vocoder_model.parameters()).device + if self.use_cuda: + vocoder_device = "cuda" + + if not reference_wav: # not voice conversion + for sen in sens: + if hasattr(self.tts_model, "synthesize"): + outputs = self.tts_model.synthesize( + text=sen, + config=self.tts_config, + speaker_id=speaker_name, + voice_dirs=self.voice_dir, + d_vector=speaker_embedding, + speaker_wav=speaker_wav, + language=language_name, + **kwargs, + ) + else: + # synthesize voice + outputs = synthesis( + model=self.tts_model, + text=sen, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + speaker_id=speaker_id, + style_wav=style_wav, + style_text=style_text, + use_griffin_lim=use_gl, + d_vector=speaker_embedding, + language_id=language_id, + ) + waveform = outputs["wav"] + if not use_gl: + mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() + # denormalize tts output based on tts audio config + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + print(" > interpolating tts model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl: + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + waveform = waveform.squeeze() + + # trim silence + if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]: + waveform = trim_silence(waveform, self.tts_model.ap) + + wavs += list(waveform) + wavs += [0] * 10000 + else: + # get the speaker embedding or speaker id for the reference wav file + reference_speaker_embedding = None + reference_speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): + if reference_speaker_name and isinstance(reference_speaker_name, str): + if self.tts_config.use_d_vector_file: + # get the speaker embedding from the saved d_vectors. + reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name( + reference_speaker_name + )[0] + reference_speaker_embedding = np.array(reference_speaker_embedding)[ + None, : + ] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + reference_speaker_id = self.tts_model.speaker_manager.name_to_id[reference_speaker_name] + else: + reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip( + reference_wav + ) + outputs = transfer_voice( + model=self.tts_model, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + reference_wav=reference_wav, + speaker_id=speaker_id, + d_vector=speaker_embedding, + use_griffin_lim=use_gl, + reference_speaker_id=reference_speaker_id, + reference_d_vector=reference_speaker_embedding, + ) + waveform = outputs + if not use_gl: + mel_postnet_spec = outputs[0].detach().cpu().numpy() + # denormalize tts output based on tts audio config + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + print(" > interpolating tts model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"): + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + wavs = waveform.squeeze() + + # compute stats + process_time = time.time() - start_time + audio_time = len(wavs) / self.tts_config.audio["sample_rate"] + print(f" > Processing time: {process_time}") + print(f" > Real-time factor: {process_time / audio_time}") + return wavs diff --git a/TTS/utils/training.py b/TTS/utils/training.py new file mode 100644 index 0000000..b51f55e --- /dev/null +++ b/TTS/utils/training.py @@ -0,0 +1,44 @@ +import numpy as np +import torch + + +def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): + r"""Check model gradient against unexpected jumps and failures""" + skip_flag = False + if ignore_stopnet: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_( + [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + else: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + + # compatibility with different torch versions + if isinstance(grad_norm, float): + if np.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + else: + if torch.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + return grad_norm, skip_flag + + +def gradual_training_scheduler(global_step, config): + """Setup the gradual training schedule wrt number + of active GPUs""" + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + num_gpus = 1 + new_values = None + # we set the scheduling wrt num_gpus + for values in config.gradual_training: + if global_step * num_gpus >= values[0]: + new_values = values + return new_values[1], new_values[2] diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py new file mode 100644 index 0000000..aefce2b --- /dev/null +++ b/TTS/utils/vad.py @@ -0,0 +1,88 @@ +import torch +import torchaudio + + +def read_audio(path): + wav, sr = torchaudio.load(path) + + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + + return wav.squeeze(0), sr + + +def resample_wav(wav, sr, new_sr): + wav = wav.unsqueeze(0) + transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr) + wav = transform(wav) + return wav.squeeze(0) + + +def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False): + factor = new_sr / vad_sr + new_timestamps = [] + if just_begging_end and timestamps: + # get just the start and end timestamps + new_dict = {"start": int(timestamps[0]["start"] * factor), "end": int(timestamps[-1]["end"] * factor)} + new_timestamps.append(new_dict) + else: + for ts in timestamps: + # map to the new SR + new_dict = {"start": int(ts["start"] * factor), "end": int(ts["end"] * factor)} + new_timestamps.append(new_dict) + + return new_timestamps + + +def get_vad_model_and_utils(use_cuda=False, use_onnx=False): + model, utils = torch.hub.load( + repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=use_onnx, force_onnx_cpu=True + ) + if use_cuda: + model = model.cuda() + + get_speech_timestamps, save_audio, _, _, collect_chunks = utils + return model, get_speech_timestamps, save_audio, collect_chunks + + +def remove_silence( + model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False +): + # get the VAD model and utils functions + model, get_speech_timestamps, _, collect_chunks = model_and_utils + + # read ground truth wav and resample the audio for the VAD + try: + wav, gt_sample_rate = read_audio(audio_path) + except: + print(f"> ❗ Failed to read {audio_path}") + return None, False + + # if needed, resample the audio for the VAD model + if gt_sample_rate != vad_sample_rate: + wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate) + else: + wav_vad = wav + + if use_cuda: + wav_vad = wav_vad.cuda() + + # get speech timestamps from full audio file + speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768) + + # map the current speech_timestamps to the sample rate of the ground truth audio + new_speech_timestamps = map_timestamps_to_new_sr( + vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end + ) + + # if have speech timestamps else save the wav + if new_speech_timestamps: + wav = collect_chunks(new_speech_timestamps, wav) + is_speech = True + else: + print(f"> The file {audio_path} probably does not have speech please check it !!") + is_speech = False + + # save + torchaudio.save(out_path, wav[None, :], gt_sample_rate) + return out_path, is_speech