Add files via upload
This commit is contained in:
@@ -0,0 +1 @@
|
||||
0.22.0
|
||||
+458
@@ -0,0 +1,458 @@
|
||||
import tempfile
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.audio.numpy_transforms import save_wav
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
from TTS.config import load_config
|
||||
|
||||
|
||||
class TTS(nn.Module):
|
||||
"""TODO: Add voice conversion and Capacitron support."""
|
||||
|
||||
def __init__(
    self,
    model_name: str = "",
    model_path: str = None,
    config_path: str = None,
    vocoder_path: str = None,
    vocoder_config_path: str = None,
    progress_bar: bool = True,
    gpu=False,
):
    """🐸TTS python interface for loading and using released or local models.

    Example with a multi-speaker, multi-lingual model:
        >>> from TTS.api import TTS
        >>> tts = TTS(model_name=multi_speaker_model_name)
        >>> wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
        >>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")

    Example with a single-speaker model:
        >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
        >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")

    Example loading a model from a path:
        >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
        >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")

    Example voice cloning with YourTTS in English, French and Portuguese:
        >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
        >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav")
        >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
        >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")

    Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):
        >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
        >>> tts.tts_to_file("This is a test.", file_path="output.wav")

    Args:
        model_name (str, optional): Name of the model to load. Names containing "tts_models" are loaded as
            TTS models, names containing "voice_conversion_models" as voice conversion models, anything
            else goes through `load_model_by_name`. Defaults to "" (nothing is loaded by name).
        model_path (str, optional): Path to a model checkpoint. When set, the model is loaded from this
            path after any by-name loading. Defaults to None.
        config_path (str, optional): Path to the model config. When set it is also parsed into
            `self.config`. Defaults to None.
        vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
        vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
        progress_bar (bool, optional): Forwarded to the model manager's `progress_bar` option.
            Defaults to True.
        gpu (bool, optional): Forwarded to the model loaders; a truthy value also emits a deprecation
            warning. Defaults to False.
    """
    super().__init__()
    models_file = self.get_models_file_path()
    self.manager = ModelManager(models_file=models_file, progress_bar=progress_bar, verbose=False)
    if config_path:
        self.config = load_config(config_path)
    else:
        self.config = None
    self.synthesizer = None
    self.voice_converter = None
    self.model_name = ""
    if gpu:
        warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")

    if model_name is not None and len(model_name) > 0:
        # Pick the loader from the model-name prefix, then invoke it once.
        if "tts_models" in model_name:
            loader = self.load_tts_model_by_name
        elif "voice_conversion_models" in model_name:
            loader = self.load_vc_model_by_name
        else:
            loader = self.load_model_by_name
        loader(model_name, gpu)

    if model_path:
        self.load_tts_model_by_path(
            model_path,
            config_path,
            vocoder_path=vocoder_path,
            vocoder_config=vocoder_config_path,
            gpu=gpu,
        )
|
||||
|
||||
@property
def models(self):
    """Return whatever `self.manager.list_tts_models()` reports."""
    available = self.manager.list_tts_models()
    return available
|
||||
|
||||
@property
def is_multi_speaker(self):
    """Whether the loaded TTS model exposes more than one speaker.

    Returns:
        bool: True when the synthesizer's TTS model has a truthy `speaker_manager` whose
        `num_speakers` is greater than 1. False otherwise, including when no synthesizer is
        loaded (e.g. only a voice conversion model was loaded).
    """
    # `self.synthesizer` is None until a TTS model is loaded; go through getattr so the
    # property answers False instead of raising AttributeError on `None.tts_model`.
    tts_model = getattr(self.synthesizer, "tts_model", None)
    speaker_manager = getattr(tts_model, "speaker_manager", None)
    if speaker_manager:
        return speaker_manager.num_speakers > 1
    return False
|
||||
|
||||
@property
def is_multi_lingual(self):
    """Whether the loaded TTS model supports more than one language.

    The answer is True when any of the following holds:

    - `self.model_name` is a string containing "xtts";
    - `self.config` is set and either its `model` contains "xtts" or it declares more than
      one entry in `languages`;
    - the synthesizer's TTS model has a truthy `language_manager` with `num_languages > 1`.

    Returns:
        bool: False when none of the above applies, including when no synthesizer is loaded.
    """
    # NOTE(review): `model_name` can apparently end up as None (origin unconfirmed), hence the
    # isinstance guard before the substring test.
    if isinstance(self.model_name, str) and "xtts" in self.model_name:
        return True

    if self.config:
        if "xtts" in self.config.model:
            return True
        # Not every model config declares `languages`; treat a missing/None field as empty
        # instead of raising AttributeError/TypeError.
        config_languages = getattr(self.config, "languages", None)
        if config_languages is not None and len(config_languages) > 1:
            return True

    # `self.synthesizer` is None until a TTS model is loaded; avoid `None.tts_model`.
    tts_model = getattr(self.synthesizer, "tts_model", None)
    language_manager = getattr(tts_model, "language_manager", None)
    if language_manager:
        return language_manager.num_languages > 1
    return False
|
||||
|
||||
@property
def speakers(self):
    """Speaker names of the loaded model, or None when it is not multi-speaker.

    Returns:
        The `speaker_names` of the TTS model's speaker manager when `is_multi_speaker` is
        truthy, otherwise None.
    """
    if self.is_multi_speaker:
        speaker_manager = self.synthesizer.tts_model.speaker_manager
        return speaker_manager.speaker_names
    return None
|
||||
|
||||
@property
def languages(self):
    """Language names of the loaded model, or None when it is not multi-lingual.

    Returns:
        The `language_names` of the TTS model's language manager when `is_multi_lingual` is
        truthy, otherwise None.
    """
    if self.is_multi_lingual:
        language_manager = self.synthesizer.tts_model.language_manager
        return language_manager.language_names
    return None
|
||||
|
||||
@staticmethod
def get_models_file_path():
    """Return the path of the `.models.json` file that sits next to this module."""
    module_dir = Path(__file__).parent
    return module_dir / ".models.json"
|
||||
|
||||
def list_models(self):
    """Build a new, non-verbose `ModelManager` for the bundled models file.

    Note that this returns the manager object itself (with the progress bar disabled), not a
    list of model names.
    """
    models_file = TTS.get_models_file_path()
    return ModelManager(models_file=models_file, progress_bar=False, verbose=False)
|
||||
|
||||
def download_model_by_name(self, model_name: str):
    """Download a model (and its default vocoder, if any) through the model manager.

    Args:
        model_name (str): Name of the model to download.

    Returns:
        tuple: `(model_path, config_path, vocoder_path, vocoder_config_path, model_dir)`.
        For "fairseq" models, or models whose `model_url` is a list, only `model_dir` is set.
        Otherwise `model_dir` is None and the vocoder entries are set only when the model item
        names a `default_vocoder`.
    """
    model_path, config_path, model_item = self.manager.download_model(model_name)
    if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
        # return model directory if there are multiple files
        # we assume that the model knows how to load itself
        return None, None, None, None, model_path
    # The manager may hand back no model item (the branch above already allows for that);
    # without one there is no default vocoder to look up.
    if model_item is None or model_item.get("default_vocoder") is None:
        return model_path, config_path, None, None, None
    vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
    return model_path, config_path, vocoder_path, vocoder_config_path, None
|
||||
|
||||
def load_model_by_name(self, model_name: str, gpu: bool = False):
    """Load a 🐸TTS model by name; delegates to `load_tts_model_by_name`.

    Args:
        model_name (str): Name of the model to load.
        gpu (bool, optional): Passed through to `load_tts_model_by_name`. Defaults to False.
    """
    load = self.load_tts_model_by_name
    load(model_name, gpu)
|
||||
|
||||
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
    """Download a voice conversion model by name and install it as `self.voice_converter`.

    Args:
        model_name (str): Name of the voice conversion model. Also stored in `self.model_name`.
        gpu (bool, optional): Passed to the `Synthesizer` as `use_cuda`. Defaults to False.
    """
    self.model_name = model_name
    # Only the checkpoint and config entries of the download result are used here.
    vc_checkpoint, vc_config, _, _, _ = self.download_model_by_name(model_name)
    converter = Synthesizer(vc_checkpoint=vc_checkpoint, vc_config=vc_config, use_cuda=gpu)
    self.voice_converter = converter
|
||||
|
||||
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
    """Download a 🐸TTS model by name and install it as `self.synthesizer`.

    The current synthesizer is cleared first, so a failed download leaves `self.synthesizer`
    as None rather than pointing at the previous model.

    Args:
        model_name (str): Name of the model to load. Also stored in `self.model_name`.
        gpu (bool, optional): Passed to the `Synthesizer` as `use_cuda`. Defaults to False.

    TODO: Add tests
    """
    self.synthesizer = None
    self.model_name = model_name

    downloaded = self.download_model_by_name(model_name)
    model_path, config_path, vocoder_path, vocoder_config_path, model_dir = downloaded

    # Speaker/language/encoder files are passed as None; the original comment states such
    # values are fetched from the model.
    synthesizer_kwargs = {
        "tts_checkpoint": model_path,
        "tts_config_path": config_path,
        "tts_speakers_file": None,
        "tts_languages_file": None,
        "vocoder_checkpoint": vocoder_path,
        "vocoder_config": vocoder_config_path,
        "encoder_checkpoint": None,
        "encoder_config": None,
        "model_dir": model_dir,
        "use_cuda": gpu,
    }
    self.synthesizer = Synthesizer(**synthesizer_kwargs)
|
||||
|
||||
def load_tts_model_by_path(
    self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
):
    """Build `self.synthesizer` from local checkpoint and config files.

    Args:
        model_path (str): Path to the model checkpoint.
        config_path (str): Path to the model config.
        vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
        vocoder_config (str, optional): Path to the vocoder config. Defaults to None.
        gpu (bool, optional): Passed to the `Synthesizer` as `use_cuda`. Defaults to False.
    """
    synthesizer_kwargs = {
        "tts_checkpoint": model_path,
        "tts_config_path": config_path,
        "tts_speakers_file": None,
        "tts_languages_file": None,
        "vocoder_checkpoint": vocoder_path,
        "vocoder_config": vocoder_config,
        "encoder_checkpoint": None,
        "encoder_config": None,
        "use_cuda": gpu,
    }
    self.synthesizer = Synthesizer(**synthesizer_kwargs)
|
||||
|
||||
def _check_arguments(
    self,
    speaker: str = None,
    language: str = None,
    speaker_wav: str = None,
    emotion: str = None,
    speed: float = None,
    **kwargs,
) -> None:
    """Validate synthesis arguments against the capabilities of the loaded model.

    Args:
        speaker (str, optional): Requested speaker name.
        language (str, optional): Requested language.
        speaker_wav (str, optional): Reference wav; counts as a speaker choice for
            multi-speaker models.
        emotion (str, optional): Rejected only when `speed` is given as well.
        speed (float, optional): Rejected only when `emotion` is given as well.
        **kwargs: Extra model arguments; a "voice_dir" key permits `speaker` on a model that is
            not multi-speaker.

    Raises:
        ValueError: If the arguments do not match the model's speaker/language support, or if
            both `emotion` and `speed` are provided.
    """
    multi_speaker = self.is_multi_speaker
    if multi_speaker and speaker is None and speaker_wav is None:
        raise ValueError("Model is multi-speaker but no `speaker` is provided.")

    multi_lingual = self.is_multi_lingual
    if multi_lingual and language is None:
        raise ValueError("Model is multi-lingual but no `language` is provided.")

    if speaker is not None and not multi_speaker and "voice_dir" not in kwargs:
        raise ValueError("Model is not multi-speaker but `speaker` is provided.")

    if language is not None and not multi_lingual:
        raise ValueError("Model is not multi-lingual but `language` is provided.")

    if emotion is not None and speed is not None:
        raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")
|
||||
|
||||
def tts(
    self,
    text: str,
    speaker: str = None,
    language: str = None,
    speaker_wav: str = None,
    emotion: str = None,
    speed: float = None,
    split_sentences: bool = True,
    **kwargs,
):
    """Convert text to speech.

    The arguments are validated with `_check_arguments` and then handed to the loaded
    synthesizer.

    Args:
        text (str):
            Input text to synthesize.
        speaker (str, optional):
            Speaker name for multi-speaker models, forwarded as `speaker_name`. See
            `tts.is_multi_speaker` and `tts.speakers`. Defaults to None.
        language (str, optional):
            Language of the text, forwarded as `language_name`. See `tts.is_multi_lingual` and
            `tts.languages`. Defaults to None.
        speaker_wav (str, optional):
            Path to a reference wav file to use for voice cloning with supporting models like
            YourTTS. Defaults to None.
        emotion (str, optional):
            Only validated, never forwarded to the synthesizer. Defaults to None.
        speed (float, optional):
            Only validated, never forwarded to the synthesizer. Defaults to None.
        split_sentences (bool, optional):
            Forwarded to the synthesizer; per the original documentation, splits the text into
            sentences that are synthesized separately and concatenated. Defaults to True.
        kwargs (dict, optional):
            Additional arguments for the model; used in validation and forwarded as-is.

    Returns:
        The value returned by `self.synthesizer.tts`.
    """
    self._check_arguments(
        speaker=speaker,
        language=language,
        speaker_wav=speaker_wav,
        emotion=emotion,
        speed=speed,
        **kwargs,
    )
    return self.synthesizer.tts(
        text=text,
        speaker_name=speaker,
        language_name=language,
        speaker_wav=speaker_wav,
        reference_wav=None,
        style_wav=None,
        style_text=None,
        reference_speaker_name=None,
        split_sentences=split_sentences,
        **kwargs,
    )
|
||||
|
||||
def tts_to_file(
    self,
    text: str,
    speaker: str = None,
    language: str = None,
    speaker_wav: str = None,
    emotion: str = None,
    speed: float = 1.0,
    pipe_out=None,
    file_path: str = "output.wav",
    split_sentences: bool = True,
    **kwargs,
):
    """Convert text to speech and save the result through the synthesizer.

    Args:
        text (str):
            Input text to synthesize.
        speaker (str, optional):
            Speaker name for multi-speaker models. See `tts.is_multi_speaker` and
            `tts.speakers`. Defaults to None.
        language (str, optional):
            Language code for multi-lingual models. See `tts.is_multi_lingual` and
            `tts.languages`. Defaults to None.
        speaker_wav (str, optional):
            Path to a reference wav file to use for voice cloning with supporting models like
            YourTTS. Defaults to None.
        emotion (str, optional):
            Accepted but not used by this method. Defaults to None.
        speed (float, optional):
            Accepted but not used by this method. Defaults to 1.0.
        pipe_out (BytesIO, optional):
            Passed to `self.synthesizer.save_wav`; per the original documentation, used to send
            the generated wav to stdout for shell pipes. Defaults to None.
        file_path (str, optional):
            Output file path. Defaults to "output.wav".
        split_sentences (bool, optional):
            Forwarded to `tts`. Defaults to True.
        kwargs (dict, optional):
            Additional arguments for the model.

    Returns:
        str: `file_path`, unchanged.
    """
    self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)

    synthesis_args = {
        "text": text,
        "speaker": speaker,
        "language": language,
        "speaker_wav": speaker_wav,
        "split_sentences": split_sentences,
    }
    waveform = self.tts(**synthesis_args, **kwargs)
    self.synthesizer.save_wav(wav=waveform, path=file_path, pipe_out=pipe_out)
    return file_path
|
||||
|
||||
def voice_conversion(
    self,
    source_wav: str,
    target_wav: str,
):
    """Voice conversion with FreeVC. Convert source wav to target speaker.

    Args:
        source_wav (str):
            Path to the source wav file.
        target_wav (str):
            Path to the target wav file.

    Returns:
        The value returned by `self.voice_converter.voice_conversion`.
    """
    return self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav)
|
||||
|
||||
def voice_conversion_to_file(
    self,
    source_wav: str,
    target_wav: str,
    file_path: str = "output.wav",
):
    """Voice conversion with FreeVC, saved to a file.

    Args:
        source_wav (str):
            Path to the source wav file.
        target_wav (str):
            Path to the target wav file.
        file_path (str, optional):
            Output file path. Defaults to "output.wav".

    Returns:
        str: `file_path`, unchanged.
    """
    converted = self.voice_conversion(source_wav=source_wav, target_wav=target_wav)
    # The output sample rate comes from the voice conversion model's audio config.
    output_rate = self.voice_converter.vc_config.audio.output_sample_rate
    save_wav(wav=converted, path=file_path, sample_rate=output_rate)
    return file_path
|
||||
|
||||
def tts_with_vc(
    self,
    text: str,
    language: str = None,
    speaker_wav: str = None,
    speaker: str = None,
    split_sentences: bool = True,
):
    """Convert text to speech with voice conversion.

    It combines tts with voice conversion to fake voice cloning.

    - Convert text to speech with tts.
    - Convert the output wav to target speaker with voice conversion.

    The intermediate wav is written into a temporary directory that is removed before this
    method returns, whether it succeeds or raises.

    Args:
        text (str):
            Input text to synthesize.
        language (str, optional):
            Language code for multi-lingual models. See `tts.is_multi_lingual` and
            `tts.languages`. Defaults to None.
        speaker_wav (str, optional):
            Path to the reference wav file; used as the voice conversion target.
            Defaults to None.
        speaker (str, optional):
            Speaker name for multi-speaker models. See `tts.is_multi_speaker` and
            `tts.speakers`. Defaults to None.
        split_sentences (bool, optional):
            Forwarded to `tts_to_file`. Defaults to True.

    Returns:
        The value returned by `self.voice_converter.voice_conversion`.
    """
    # Previously a NamedTemporaryFile(delete=False) was used and never removed, leaving one
    # stray wav behind per call. A TemporaryDirectory cleans up on exit, and the file inside
    # it is not held open while `tts_to_file` writes to it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_wav = str(Path(tmp_dir) / "tts_output.wav")
        # Lazy code... save it to a temp file to resample it while reading it for VC
        self.tts_to_file(
            text=text, speaker=speaker, language=language, file_path=tmp_wav, split_sentences=split_sentences
        )
        if self.voice_converter is None:
            self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
        # Conversion must finish before the directory (and the wav in it) is deleted.
        # NOTE(review): assumes the returned wav does not lazily read from `tmp_wav` — confirm.
        wav = self.voice_converter.voice_conversion(source_wav=tmp_wav, target_wav=speaker_wav)
    return wav
|
||||
|
||||
def tts_with_vc_to_file(
    self,
    text: str,
    language: str = None,
    speaker_wav: str = None,
    file_path: str = "output.wav",
    speaker: str = None,
    split_sentences: bool = True,
):
    """Convert text to speech with voice conversion and save to file.

    Check `tts_with_vc` for more details.

    Args:
        text (str):
            Input text to synthesize.
        language (str, optional):
            Language code for multi-lingual models. See `tts.is_multi_lingual` and
            `tts.languages`. Defaults to None.
        speaker_wav (str, optional):
            Path to the reference wav file; used as the voice conversion target.
            Defaults to None.
        file_path (str, optional):
            Output file path. Defaults to "output.wav".
        speaker (str, optional):
            Speaker name for multi-speaker models. See `tts.is_multi_speaker` and
            `tts.speakers`. Defaults to None.
        split_sentences (bool, optional):
            Forwarded to `tts_with_vc`. Defaults to True.

    Returns:
        str: `file_path`, matching `tts_to_file` and `voice_conversion_to_file`.
    """
    wav = self.tts_with_vc(
        text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
    )
    # `tts_with_vc` loads the voice converter on demand, so it is available by now.
    save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
    # Return the path like the sibling *_to_file methods (previously returned None).
    return file_path
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Get detailed info about the working environment."""
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
sys.path += [os.path.abspath(".."), os.path.abspath(".")]
|
||||
import json
|
||||
|
||||
import TTS
|
||||
|
||||
|
||||
def system_info():
    """Collect OS, architecture, OS version, processor and Python version via `platform`."""
    info = {}
    info["OS"] = platform.system()
    info["architecture"] = platform.architecture()
    info["version"] = platform.version()
    info["processor"] = platform.processor()
    info["python"] = platform.python_version()
    return info
|
||||
|
||||
|
||||
def cuda_info():
    """Report the visible CUDA device names, CUDA availability and torch's CUDA version."""
    device_names = [torch.cuda.get_device_name(index) for index in range(torch.cuda.device_count())]
    is_available = torch.cuda.is_available()
    return {
        "GPU": device_names,
        "available": is_available,
        "version": torch.version.cuda,
    }
|
||||
|
||||
|
||||
def package_info():
    """Report the numpy, PyTorch and TTS versions plus torch's debug-build flag."""
    versions = {}
    versions["numpy"] = numpy.__version__
    versions["PyTorch_version"] = torch.__version__
    versions["PyTorch_debug"] = torch.version.debug
    versions["TTS"] = TTS.__version__
    return versions
|
||||
|
||||
|
||||
def main():
    """Print the system, CUDA and package details as indented, key-sorted JSON."""
    details = {}
    details["System"] = system_info()
    details["CUDA"] = cuda_info()
    details["Packages"] = package_info()
    report = json.dumps(details, indent=4, sort_keys=True)
    print(report)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,165 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_checkpoint
|
||||
|
||||
def _str2bool(value):
    """Parse a command-line boolean.

    `argparse` with `type=bool` turns every non-empty string (including "False") into True.
    Here an empty string and the usual negative spellings are False; any other value stays
    True, so invocations that already produced True keep doing so.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ("", "false", "f", "no", "n", "0", "off")


if __name__ == "__main__":
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
        """Each attention mask is written next to the input wav file with the "_attn.npy" suffix.
(e.g. path/bla.wav (wav file) --> path/bla_attn.npy (attention mask))\n"""
        """
Example run:
    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda True
""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to Tacotron/Tacotron2 config file.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="",
        required=True,
        help="Target dataset formatter name from TTS.tts.datasets.formatters.",
    )

    parser.add_argument(
        "--dataset_metafile",
        type=str,
        default="",
        required=True,
        help="Dataset metafile inclusing file paths with transcripts.",
    )
    parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
    # `type=bool` would make `--use_cuda False` evaluate to True; parse the text instead.
    parser.add_argument("--use_cuda", type=_str2bool, default=False, help="enable/disable cuda.")

    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
    )
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if "characters" in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load the model
    # NOTE(review): `num_chars` is not read anywhere below; kept in case it is relied on elsewhere.
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(C)
    model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)

    # data loader: look up the formatter function named by --dataset
    preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        model.decoder.r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )

    dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data (only the batch entries used below are unpacked)
            text_input = data[0]
            text_lengths = data[1]
            mel_input = data[4]
            mel_lengths = data[5]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            model_outputs = model.forward(text_input, text_lengths, mel_input)

            alignments = model_outputs["alignments"].detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # interpolate if r > 1
                alignment = (
                    torch.nn.functional.interpolate(
                        alignment.transpose(0, 1).unsqueeze(0),
                        size=None,
                        scale_factor=model.decoder.r,
                        mode="nearest",
                        align_corners=None,
                        recompute_scale_factor=None,
                    )
                    .squeeze(0)
                    .transpose(0, 1)
                )
                # remove paddings
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                # set file paths: swap only the file name; a plain str.replace could also
                # rewrite a directory component that happens to contain the file name
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
                file_path = os.path.join(os.path.dirname(item_idx), align_file_name)
                # save output
                wav_file_abs_path = os.path.abspath(item_idx)
                file_abs_path = os.path.abspath(file_path)
                file_paths.append([wav_file_abs_path, file_abs_path])
                np.save(file_path, alignment)

    # output metafile: one "wav_path|mask_path" line per item
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

    with open(metafile, "w", encoding="utf-8") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")
|
||||
@@ -0,0 +1,197 @@
|
||||
import argparse
|
||||
import os
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.managers import save_file
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
||||
|
||||
def compute_embeddings(
    model_path,
    config_path,
    output_path,
    old_speakers_file=None,
    old_append=False,
    config_dataset_path=None,
    formatter_name=None,
    dataset_name=None,
    dataset_path=None,
    meta_file_train=None,
    meta_file_val=None,
    disable_cuda=False,
    no_eval=False,
):
    """Compute an embedding for every sample of a dataset and save the mapping to a file.

    Args:
        model_path: Encoder checkpoint, passed to `SpeakerManager` as `encoder_model_path`.
        config_path: Encoder config, passed to `SpeakerManager` as `encoder_config_path`.
        output_path: Output file, or an existing directory in which "speakers.pth" is written.
        old_speakers_file: Existing embedding file; embeddings of clips already present in it
            are reused instead of recomputed.
        old_append: With `old_speakers_file`, start from the old embeddings and add to them.
        config_dataset_path: Dataset config file. When None, a `BaseDatasetConfig` is built
            from the formatter/dataset arguments below.
        formatter_name: Formatter for the ad-hoc dataset config.
        dataset_name: Dataset name for the ad-hoc dataset config.
        dataset_path: Dataset path for the ad-hoc dataset config.
        meta_file_train: Train meta file for the ad-hoc dataset config (only set if given).
        meta_file_val: Eval meta file for the ad-hoc dataset config (only set if given).
        disable_cuda: Do not use CUDA even when it is available.
        no_eval: Load the samples with `eval_split=False`.
    """
    use_cuda = torch.cuda.is_available() and not disable_cuda

    if config_dataset_path is None:
        dataset_config = BaseDatasetConfig()
        dataset_config.formatter = formatter_name
        dataset_config.dataset_name = dataset_name
        dataset_config.path = dataset_path
        if meta_file_train is not None:
            dataset_config.meta_file_train = meta_file_train
        if meta_file_val is not None:
            dataset_config.meta_file_val = meta_file_val
        datasets = dataset_config
    else:
        datasets = load_config(config_dataset_path).datasets
    meta_data_train, meta_data_eval = load_tts_samples(datasets, eval_split=not no_eval)

    samples = meta_data_train if meta_data_eval is None else meta_data_train + meta_data_eval

    encoder_manager = SpeakerManager(
        encoder_model_path=model_path,
        encoder_config_path=config_path,
        d_vectors_file_path=old_speakers_file,
        use_cuda=use_cuda,
    )

    class_name_key = encoder_manager.encoder_config.class_name_key

    # compute speaker embeddings
    has_old_file = old_speakers_file is not None
    speaker_mapping = encoder_manager.embeddings if has_old_file and old_append else {}

    for sample in tqdm(samples):
        class_name = sample[class_name_key]
        audio_file = sample["audio_file"]
        embedding_key = sample["audio_unique_name"]

        # Entries already present (appended old file) only get their name refreshed.
        if embedding_key in speaker_mapping:
            speaker_mapping[embedding_key]["name"] = class_name
            continue

        if has_old_file and embedding_key in encoder_manager.clip_ids:
            # reuse the embedding stored in the old file
            embedding = encoder_manager.get_embedding_by_clip(embedding_key)
        else:
            # extract the embedding from the audio clip
            embedding = encoder_manager.compute_embedding_from_clip(audio_file)

        speaker_mapping[embedding_key] = {"name": class_name, "embedding": embedding}

    if not speaker_mapping:
        return

    # A directory target receives the default file name; anything else is used as the file path.
    if os.path.isdir(output_path):
        mapping_file_path = os.path.join(output_path, "speakers.pth")
    else:
        mapping_file_path = output_path

    parent_dir = os.path.dirname(mapping_file_path)
    if parent_dir != "":
        os.makedirs(parent_dir, exist_ok=True)

    save_file(speaker_mapping, mapping_file_path)
    print("Speaker embeddings saved at:", mapping_file_path)
|
||||
|
||||
|
||||
def _str2bool(value):
    """Parse a command-line boolean.

    `argparse` with `type=bool` turns every non-empty string (including "False") into True.
    Here an empty string and the usual negative spellings are False; any other value stays
    True, so invocations that already produced True keep doing so.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ("", "false", "f", "no", "n", "0", "off")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
        """
        Example runs:
        python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json  --config_dataset_path dataset_config.json

        python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json  --formatter_name coqui --dataset_path /path/to/vctk/dataset --dataset_name my_vctk --meta_file_train /path/to/vctk/metafile_train.csv --meta_file_val /path/to/vctk/metafile_eval.csv
        """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to model checkpoint file. It defaults to the released speaker encoder.",
        default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to model config file. It defaults to the released speaker encoder config.",
        default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json",
    )
    parser.add_argument(
        "--config_dataset_path",
        type=str,
        help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.",
        default=None,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path for output `pth` or `json` file.",
        default="speakers.pth",
    )
    parser.add_argument(
        "--old_file",
        type=str,
        help="The old existing embedding file, from which the embeddings will be directly loaded for already computed audio clips.",
        default=None,
    )
    parser.add_argument(
        "--old_append",
        help="Append new audio clip embeddings to the old embedding file, generate a new non-duplicated merged embedding file. Default False",
        default=False,
        action="store_true",
    )
    # `type=bool` would make `--disable_cuda False` disable CUDA; parse the text instead.
    parser.add_argument("--disable_cuda", type=_str2bool, help="Flag to disable cuda.", default=False)
    parser.add_argument("--no_eval", help="Do not compute eval?. Default False", default=False, action="store_true")
    parser.add_argument(
        "--formatter_name",
        type=str,
        help="Name of the formatter to use. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name of the dataset to use. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        help="Path to the dataset. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--meta_file_train",
        type=str,
        help="Path to the train meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--meta_file_val",
        type=str,
        help="Path to the evaluation meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    args = parser.parse_args()

    # Hand the parsed options to the library entry point.
    compute_embeddings(
        args.model_path,
        args.config_path,
        args.output_path,
        old_speakers_file=args.old_file,
        old_append=args.old_append,
        config_dataset_path=args.config_dataset_path,
        formatter_name=args.formatter_name,
        dataset_name=args.dataset_name,
        dataset_path=args.dataset_path,
        meta_file_train=args.meta_file_train,
        meta_file_val=args.meta_file_val,
        disable_cuda=args.disable_cuda,
        no_eval=args.no_eval,
    )
|
||||
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
# from TTS.utils.io import load_config
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
def main():
    """Compute per-bin mean and standard deviation of mel and linear spectrograms.

    Command-line arguments:
        config_path: TTS config file that defines the audio processing parameters.
        out_path: File the statistics dictionary is written to with ``np.save``.
        --data_path: Optional folder searched recursively for ``*.wav`` files; when
            given it is used instead of the datasets listed in the config.

    Arguments the parser does not recognise are handed to the config object as
    overrides.

    Raises:
        RuntimeError: If no spectrogram frames were collected, i.e. no audio was found.
    """
    parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
    parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.")
    parser.add_argument("out_path", type=str, help="save path (directory and filename).")
    parser.add_argument(
        "--data_path",
        type=str,
        required=False,
        help="folder including the target set of wavs overriding dataset config.",
    )
    args, overrides = parser.parse_known_args()

    CONFIG = load_config(args.config_path)
    CONFIG.parse_known_args(overrides, relaxed_parser=True)

    # The statistics must describe raw features, so switch off any normalization
    # and any previously computed stats before building the audio processor.
    CONFIG.audio.signal_norm = False
    CONFIG.audio.stats_path = None

    ap = AudioProcessor(**CONFIG.audio.to_dict())

    # Collect the audio items: either plain wav paths or dataset samples.
    if args.data_path:
        dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True)
    else:
        dataset_items = load_tts_samples(CONFIG.datasets)[0]  # take only train data
    print(f" > There are {len(dataset_items)} files.")

    mel_sum = 0
    mel_square_sum = 0
    linear_sum = 0
    linear_square_sum = 0
    N = 0
    for item in tqdm(dataset_items):
        # Items are path strings (--data_path) or dicts carrying "audio_file".
        wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
        linear = ap.spectrogram(wav)
        mel = ap.melspectrogram(wav)

        # Running sums over the frame axis (axis 1).
        N += mel.shape[1]
        mel_sum += mel.sum(1)
        linear_sum += linear.sum(1)
        mel_square_sum += (mel**2).sum(axis=1)
        linear_square_sum += (linear**2).sum(axis=1)

    if N == 0:
        # Without this guard the divisions below fail with an opaque ZeroDivisionError.
        raise RuntimeError("No spectrogram frames were collected; check the dataset config or --data_path.")

    mel_mean = mel_sum / N
    linear_mean = linear_sum / N
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clamp it so the
    # square root never produces NaN.
    mel_scale = np.sqrt(np.maximum(mel_square_sum / N - mel_mean**2, 0))
    linear_scale = np.sqrt(np.maximum(linear_square_sum / N - linear_mean**2, 0))

    output_file_path = args.out_path
    stats = {
        "mel_mean": mel_mean,
        "mel_std": mel_scale,
        "linear_mean": linear_mean,
        "linear_std": linear_scale,
    }

    print(f" > Avg mel spec mean: {mel_mean.mean()}")
    print(f" > Avg mel spec scale: {mel_scale.mean()}")
    print(f" > Avg linear spec mean: {linear_mean.mean()}")
    print(f" > Avg linear spec scale: {linear_scale.mean()}")

    # Store the audio config alongside the stats, switched to mean-var scaling.
    CONFIG.audio.stats_path = output_file_path
    CONFIG.audio.signal_norm = True
    # remove redundant values
    del CONFIG.audio.max_norm
    del CONFIG.audio.min_level_db
    del CONFIG.audio.symmetric_norm
    del CONFIG.audio.clip_norm
    stats["audio_config"] = CONFIG.audio.to_dict()
    np.save(output_file_path, stats, allow_pickle=True)
    print(f" > stats saved to {output_file_path}")
|
||||
|
||||
|
||||
# Run the statistics computation only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,88 @@
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
||||
|
||||
def compute_encoder_accuracy(dataset_items, encoder_manager):
    """Print the per-class and the class-averaged accuracy of an encoder classifier.

    Args:
        dataset_items: Iterable of sample dicts. Each must hold "audio_file" and the
            key named by ``encoder_manager.encoder_config.class_name_key``.
        encoder_manager: Object exposing ``encoder_config``, ``encoder_criterion``,
            ``use_cuda`` and ``compute_embedding_from_clip``.

    Raises:
        RuntimeError: If a sample has no class name, or no label can be predicted
            because the criterion or the class-id mapping is missing.
    """
    class_name_key = encoder_manager.encoder_config.class_name_key
    map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)

    # class name -> list of 0/1 hits, one entry per evaluated sample
    class_acc_dict = {}

    for item in tqdm(dataset_items):
        class_name = item[class_name_key]
        wav_file = item["audio_file"]

        # extract the embedding
        embedd = encoder_manager.compute_embedding_from_clip(wav_file)
        if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
            embedding = torch.FloatTensor(embedd).unsqueeze(0)
            if encoder_manager.use_cuda:
                embedding = embedding.cuda()

            class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
            # The mapping is indexed with the stringified class id.
            predicted_label = map_classid_to_classname[str(class_id)]
        else:
            predicted_label = None

        if class_name is None or predicted_label is None:
            raise RuntimeError("Error: class_name or/and predicted_label are None")

        class_acc_dict.setdefault(class_name, []).append(int(class_name == predicted_label))

    if not class_acc_dict:
        # An empty dataset used to end in a ZeroDivisionError on the average below.
        print("No samples were evaluated; accuracy cannot be computed.")
        return

    acc_avg = 0
    for key, values in class_acc_dict.items():
        acc = sum(values) / len(values)
        print("Class", key, "Accuracy:", acc)
        acc_avg += acc

    print("Average Accuracy:", acc_avg / len(class_acc_dict))
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def _str2bool(value):
        """Convert a command-line token such as "True"/"false"/"0" to ``bool``.

        ``type=bool`` turns every non-empty string, including "False", into True,
        so the flags below could never be disabled from the command line.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("yes", "true", "t", "y", "1"):
            return True
        # The empty string stays falsy, as it was under ``type=bool``.
        if token in ("no", "false", "f", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(
        description="""Compute the accuracy of the encoder.\n\n"""
        """
        Example runs:
        python TTS/bin/eval_encoder.py emotion_encoder_model.pth emotion_encoder_config.json dataset_config.json
        """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to model config file.",
    )

    parser.add_argument(
        "config_dataset_path",
        type=str,
        help="Path to dataset config file.",
    )
    parser.add_argument("--use_cuda", type=_str2bool, help="flag to set cuda.", default=True)
    parser.add_argument("--eval", type=_str2bool, help="compute eval.", default=True)

    args = parser.parse_args()

    c_dataset = load_config(args.config_dataset_path)

    meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
    # NOTE(review): with eval_split=False the eval partition is presumably None or
    # empty - confirm against load_tts_samples; `or []` keeps the sum valid either way.
    items = meta_data_train + (meta_data_eval or [])

    enc_manager = SpeakerManager(
        encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
    )

    compute_encoder_accuracy(items, enc_manager)
|
||||
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Extract Mel spectrograms with teacher forcing."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import quantize
|
||||
from TTS.utils.generic_utils import count_parameters
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
|
||||
def setup_loader(ap, r, verbose=False):
    """Create a sequential (non-shuffled) ``DataLoader`` over the global ``meta_data``.

    Reads the module-level config ``c``, the sample list ``meta_data`` and the
    ``speaker_manager`` that ``main`` sets up.

    Args:
        ap: Audio processor handed to the dataset.
        r: Passed to the dataset as ``outputs_per_step``.
        verbose: Forwarded to the dataset.

    Returns:
        The ``DataLoader`` wrapping the constructed ``TTSDataset``.
    """
    tokenizer, _ = TTSTokenizer.init_from_config(c)

    dataset_options = {
        "outputs_per_step": r,
        "compute_linear_spec": False,
        "samples": meta_data,
        "tokenizer": tokenizer,
        "ap": ap,
        "batch_group_size": 0,
        "min_text_len": c.min_text_len,
        "max_text_len": c.max_text_len,
        "min_audio_len": c.min_audio_len,
        "max_audio_len": c.max_audio_len,
        "phoneme_cache_path": c.phoneme_cache_path,
        "precompute_num_workers": 0,
        "use_noise_augment": False,
        "verbose": verbose,
        # Speaker mappings are only passed when the config enables them.
        "speaker_id_mapping": speaker_manager.name_to_id if c.use_speaker_embedding else None,
        "d_vector_mapping": speaker_manager.embeddings if c.use_d_vector_file else None,
    }
    dataset = TTSDataset(**dataset_options)

    wants_phoneme_cache = c.use_phonemes and c.compute_input_seq_cache
    if wants_phoneme_cache:
        # precompute phonemes to have a better estimate of sequence lengths.
        dataset.compute_input_seq(c.num_loader_workers)
    dataset.preprocess_samples()

    return DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=c.num_loader_workers,
        pin_memory=False,
    )
|
||||
|
||||
|
||||
def set_filename(wav_path, out_path):
    """Derive the output paths for one audio file and create their folders.

    The base name is the file name with only its final extension removed, so
    dotted names such as ``spk.001.wav`` keep their full stem instead of all
    collapsing to ``spk``.

    Args:
        wav_path: Path of the source audio file.
        out_path: Root folder under which "quant", "mel", "wav_gl" and "wav" are created.

    Returns:
        Tuple ``(file_name, wavq_path, mel_path, wav_gl_path, wav_path)``. The quant
        and mel paths carry no extension; the two wav paths end in ``.wav``.
    """
    file_name = os.path.splitext(os.path.basename(wav_path))[0]

    for subfolder in ("quant", "mel", "wav_gl", "wav"):
        os.makedirs(os.path.join(out_path, subfolder), exist_ok=True)

    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
    return file_name, wavq_path, mel_path, wav_gl_path, wav_path
|
||||
|
||||
|
||||
def format_data(data):
    """Unpack a collated batch and move its tensors to the GPU when ``use_cuda`` is set.

    Args:
        data: Batch dict with the keys "token_id", "token_id_lengths", "mel",
            "mel_lengths", "item_idxs", "d_vectors", "speaker_ids" and "attns".

    Returns:
        Tuple ``(text_input, text_lengths, mel_input, mel_lengths, speaker_ids,
        d_vectors, avg_text_length, avg_spec_length, attn_mask, item_idx)``.
    """
    text_input = data["token_id"]
    text_lengths = data["token_id_lengths"]
    mel_input = data["mel"]
    mel_lengths = data["mel_lengths"]
    item_idx = data["item_idxs"]
    d_vectors = data["d_vectors"]
    speaker_ids = data["speaker_ids"]
    attn_mask = data["attns"]

    # The batch averages are taken before any device transfer.
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if use_cuda:

        def _to_gpu(tensor):
            return tensor.cuda(non_blocking=True)

        text_input = _to_gpu(text_input)
        text_lengths = _to_gpu(text_lengths)
        mel_input = _to_gpu(mel_input)
        mel_lengths = _to_gpu(mel_lengths)
        # These three may be absent from the batch; only move the ones present.
        speaker_ids = speaker_ids if speaker_ids is None else _to_gpu(speaker_ids)
        d_vectors = d_vectors if d_vectors is None else _to_gpu(d_vectors)
        attn_mask = attn_mask if attn_mask is None else _to_gpu(attn_mask)

    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_ids,
        d_vectors,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )
|
||||
|
||||
|
||||
@torch.no_grad()
def inference(
    model_name,
    model,
    ap,
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    speaker_ids=None,
    d_vectors=None,
):
    """Run a teacher-forced pass and return the model's spectrogram output as numpy.

    Args:
        model_name: One of "glow_tts", "tacotron" or "tacotron2".
        model: Model to run; "glow_tts" must provide ``inference_with_MAS``.
        ap: Audio processor; only used by "tacotron" to map linear output to mel.
        text_input, text_lengths, mel_input, mel_lengths: Batch tensors passed to the model.
        speaker_ids: Optional speaker id tensor.
        d_vectors: Optional d-vector tensor.

    Returns:
        numpy array built from ``outputs["model_outputs"]`` (converted to mel for "tacotron").

    Raises:
        ValueError: If ``model_name`` is not supported. Such names used to fall
            through and fail with an UnboundLocalError on the return statement.
    """
    if model_name == "glow_tts":
        # Speaker ids take precedence over d-vectors as the conditioning input.
        speaker_c = speaker_ids if speaker_ids is not None else d_vectors
        outputs = model.inference_with_MAS(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
        )
        return outputs["model_outputs"].detach().cpu().numpy()

    if model_name in ("tacotron", "tacotron2"):
        aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
        outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
        postnet_outputs = outputs["model_outputs"]

        if model_name == "tacotron2":
            return postnet_outputs.detach().cpu().numpy()

        # "tacotron": the postnet output is converted to mel, item by item.
        postnet_outputs = postnet_outputs.data.cpu().numpy()
        mel_specs = [
            torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T) for postnet_output in postnet_outputs
        ]
        return torch.stack(mel_specs).cpu().numpy()

    raise ValueError(
        f"Unsupported model '{model_name}'; expected 'glow_tts', 'tacotron' or 'tacotron2'."
    )
|
||||
|
||||
|
||||
def extract_spectrograms(
    data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
):
    """Save the model's spectrogram for every sample and write a metadata index.

    For each sample the trimmed, transposed model output is stored with ``np.save``
    under ``<output_path>/mel``. A ``wav_path|mel_path.npy`` line per sample is
    written to ``<output_path>/<metada_name>``.

    Args:
        data_loader: Loader yielding batches accepted by ``format_data``.
        model: Model handed to ``inference``; switched to eval mode here.
        ap: Audio processor used to load/save audio and invert mels.
        output_path: Root output folder.
        quantize_bits: When > 0, also save the quantized source audio.
        save_audio: When True, also save the source audio under "wav".
        debug: When True, also save audio rebuilt from the extracted mel under "wav_gl".
        metada_name: File name of the metadata index.
    """
    model.eval()
    export_metadata = []
    for data in tqdm(data_loader, total=len(data_loader)):
        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
            _,
            _,
            _,
            item_idx,
        ) = format_data(data)

        model_output = inference(
            c.model.lower(),
            model,
            ap,
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
        )

        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

            # The source audio is only needed for these two options; skip decoding otherwise.
            wav = ap.load_wav(wav_file_path) if (quantize_bits > 0 or save_audio) else None

            # quantize and save wav
            if quantize_bits > 0:
                wavq = quantize(wav, quantize_bits)
                np.save(wavq_path, wavq)

            # save TTS mel: drop the frames past this item's length, then transpose
            mel_length = int(mel_lengths[idx])
            mel = model_output[idx][:mel_length, :].T
            np.save(mel_path, mel)

            export_metadata.append([wav_file_path, mel_path])
            if save_audio:
                ap.save_wav(wav, wav_path)

            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                wav = ap.inv_melspectrogram(mel)
                ap.save_wav(wav, wav_gl_path)

    # With an empty loader nothing has created the folder yet; make sure it exists.
    os.makedirs(output_path, exist_ok=True)
    with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
        # np.save appended ".npy" to each mel path, so the index records it too.
        f.writelines(f"{wav_file}|{mel_file + '.npy'}\n" for wav_file, mel_file in export_metadata)
|
||||
|
||||
|
||||
def main(args):  # pylint: disable=redefined-outer-name
    """Load data and model from the global config ``c`` and extract spectrograms.

    Publishes ``meta_data`` and ``speaker_manager`` as module globals because
    ``setup_loader`` reads them from there.

    Args:
        args: Parsed command-line namespace providing ``eval``, ``checkpoint_path``,
            ``output_path``, ``quantize_bits``, ``save_audio`` and ``debug``.
    """
    # pylint: disable=global-variable-undefined
    global meta_data, speaker_manager

    # Audio processor
    ap = AudioProcessor(**c.audio)

    # load data instances
    meta_data_train, meta_data_eval = load_tts_samples(
        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )

    # use eval and training partitions
    # NOTE(review): with eval_split=False the eval partition is presumably None or
    # empty - confirm against load_tts_samples; `or []` keeps the sum valid either way.
    meta_data = meta_data_train + (meta_data_eval or [])

    # init speaker manager
    if c.use_speaker_embedding:
        speaker_manager = SpeakerManager(data_items=meta_data)
    elif c.use_d_vector_file:
        speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
    else:
        speaker_manager = None

    # setup model
    model = setup_model(c)

    # restore model
    model.load_checkpoint(c, args.checkpoint_path, eval=True)

    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)
    # set r: fixed to 1 for glow_tts, otherwise taken from the decoder
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r, verbose=True)

    extract_spectrograms(
        own_loader,
        model,
        ap,
        args.output_path,
        quantize_bits=args.quantize_bits,
        save_audio=args.save_audio,
        debug=args.debug,
        metada_name="metada.txt",
    )


if __name__ == "__main__":

    def _str2bool(value):
        """Convert a command-line token such as "True"/"false"/"0" to ``bool``.

        ``type=bool`` turns every non-empty string, including "False", into True,
        so ``--eval False`` used to keep the eval split enabled.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("yes", "true", "t", "y", "1"):
            return True
        # The empty string stays falsy, as it was under ``type=bool``.
        if token in ("no", "false", "f", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
    parser.add_argument("--eval", type=_str2bool, help="compute eval.", default=True)
    args = parser.parse_args()

    c = load_config(args.config_path)
    # Silence trimming is switched off for the extraction run.
    c.audio.trim_silence = False
    main(args)
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Find all the unique characters in a dataset"""
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
|
||||
|
||||
def main():
    """Print the unique characters found in the texts of a dataset config.

    Reads ``--config_path``, loads the train and eval samples of every dataset in
    it, and prints the character count, the characters, the lowercase ones, and
    the set obtained by lowercasing every character.
    """
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Find all the unique characters or phonemes in a dataset.\n\n"""
        """
    Example runs:

    python TTS/bin/find_unique_chars.py --config_path config.json
    """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
    args = parser.parse_args()

    c = load_config(args.config_path)

    # load all datasets
    train_items, eval_items = load_tts_samples(
        c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )

    # Gather every character that occurs in any sample text.
    chars = set()
    for sample in train_items + eval_items:
        chars.update(sample["text"])

    lower_chars = [ch for ch in chars if ch.islower()]
    chars_force_lower = {ch.lower() for ch in chars}

    print(f" > Number of unique characters: {len(chars)}")
    print(f" > Unique characters: {''.join(sorted(chars))}")
    print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
    print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
|
||||
|
||||
|
||||
# Run the character scan only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Find all the unique characters in a dataset"""
|
||||
import argparse
|
||||
import multiprocessing
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
from tqdm.contrib.concurrent import process_map
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.text.phonemizers import Gruut
|
||||
|
||||
|
||||
def compute_phonemes(item):
    """Return the set of distinct symbols in the phonemized text of one sample.

    Uses the module-level ``phonemizer`` and drops every "|" from its output
    before the symbols are collected.
    """
    phonemized = phonemizer.phonemize(item["text"])
    return set(phonemized.replace("|", ""))
|
||||
|
||||
|
||||
def main():
    """Print the unique phonemes produced for the texts of a dataset config.

    Reads ``--config_path``, loads the train and eval samples, phonemizes every
    text with Gruut in worker processes and prints the resulting symbol sets.
    Sets the globals ``c`` and ``phonemizer``; ``compute_phonemes`` reads the latter.

    Raises:
        ValueError: If the datasets contain no samples, if the phoneme language or
            a sample language is missing, or if the samples mix several languages.
    """
    # pylint: disable=W0601
    global c, phonemizer
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Find all the unique characters or phonemes in a dataset.\n\n"""
        """
    Example runs:

    python TTS/bin/find_unique_phonemes.py --config_path config.json
    """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
    args = parser.parse_args()

    c = load_config(args.config_path)

    # load all datasets
    train_items, eval_items = load_tts_samples(
        c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )
    items = train_items + eval_items
    print("Num items:", len(items))

    if not items:
        # The language checks below index the first sample; an empty dataset
        # used to end in an IndexError there.
        raise ValueError("No samples found in the datasets defined by the config.")

    language_list = [item["language"] for item in items]
    is_lang_def = all(language_list)

    if not c.phoneme_language or not is_lang_def:
        raise ValueError("Phoneme language must be defined in config.")

    if len(set(language_list)) != 1:
        raise ValueError(
            "Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!"
        )

    phonemizer = Gruut(language=language_list[0], keep_puncs=True)

    phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)

    # Merge the per-sample symbol sets into one.
    phones = set().union(*phonemes)
    lower_phones = [ph for ph in phones if ph.islower()]
    phones_force_lower = {ph.lower() for ph in phones}

    print(f" > Number of unique phonemes: {len(phones)}")
    print(f" > Unique phonemes: {''.join(sorted(phones))}")
    print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
    print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")
|
||||
|
||||
|
||||
# Run the phoneme scan only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,124 @@
|
||||
import argparse
|
||||
import glob
|
||||
import multiprocessing
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.utils.vad import get_vad_model_and_utils, remove_silence
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
def adjust_path_and_remove_silence(audio_path):
    """Map an input file to its output location and write the silence-removed audio.

    Reads the module-level ``args`` and ``model_and_utils``.

    Args:
        audio_path: Path of an audio file under ``args.input_dir``.

    Returns:
        Tuple ``(output_path, is_speech)``. For a file that already exists and is
        not forced, the result is ``(output_path, False)`` without any processing.
    """
    input_prefix = os.path.join(args.input_dir, "")
    output_prefix = os.path.join(args.output_dir, "")
    # Swap only the first occurrence: an unbounded replace would also rewrite a
    # nested folder that happens to repeat the input directory string.
    output_path = audio_path.replace(input_prefix, output_prefix, 1)

    # ignore if the file exists
    if os.path.exists(output_path) and not args.force:
        return output_path, False

    # create all directory structure
    pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # remove the silence and save the audio
    output_path, is_speech = remove_silence(
        model_and_utils,
        audio_path,
        output_path,
        trim_just_beginning_and_end=args.trim_just_beginning_and_end,
        use_cuda=args.use_cuda,
    )
    return output_path, is_speech
|
||||
|
||||
|
||||
def preprocess_audios():
    """Process every file matched by ``args.glob`` and list the ones without speech.

    Reads the module-level ``args``. Files run through
    ``adjust_path_and_remove_silence`` in a process pool when
    ``args.num_processes`` exceeds one, sequentially otherwise. Output paths
    reported as non-speech are written to ``filtered_files.txt`` in the output folder.
    """
    files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
    print("> Number of files: ", len(files))
    if not args.force:
        print("> Ignoring files that already exist in the output idrectory.")

    if args.trim_just_beginning_and_end:
        print("> Trimming just the beginning and the end with nonspeech parts.")
    else:
        print("> Trimming all nonspeech parts.")

    if not files:
        print("> No files Found !")
        return

    if args.num_processes > 1:
        with multiprocessing.Pool(processes=args.num_processes) as pool:
            results = list(
                tqdm(
                    pool.imap_unordered(adjust_path_and_remove_silence, files),
                    total=len(files),
                    desc="Processing audio files",
                )
            )
    else:
        results = [adjust_path_and_remove_silence(path) for path in tqdm(files)]

    # Keep only the outputs flagged as containing no speech.
    filtered_files = [out_path for out_path, has_speech in results if not has_speech]

    # write files that do not have speech
    with open(os.path.join(args.output_dir, "filtered_files.txt"), "w", encoding="utf-8") as f:
        f.writelines(str(entry) + "\n" for entry in filtered_files)
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def _str2bool(value):
        """Convert a command-line token such as "True"/"false"/"0" to ``bool``.

        ``type=bool`` turns every non-empty string, including "False", into True,
        so passing False to the boolean options below had no effect.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("yes", "true", "t", "y", "1"):
            return True
        # The empty string stays falsy, as it was under ``type=bool``.
        if token in ("no", "false", "f", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(
        description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True"
    )
    parser.add_argument("-i", "--input_dir", type=str, help="Dataset root dir", required=True)
    parser.add_argument("-o", "--output_dir", type=str, help="Output Dataset dir", default="")
    parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
    parser.add_argument(
        "-g",
        "--glob",
        type=str,
        default="**/*.wav",
        help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
    )
    parser.add_argument(
        "-t",
        "--trim_just_beginning_and_end",
        type=_str2bool,
        default=True,
        help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. Default True",
    )
    parser.add_argument(
        "-c",
        "--use_cuda",
        type=_str2bool,
        default=False,
        help="If True use cuda",
    )
    parser.add_argument(
        "--use_onnx",
        type=_str2bool,
        default=False,
        help="If True use onnx",
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=1,
        help="Number of processes to use",
    )
    args = parser.parse_args()

    # Without an explicit output folder the results are written next to the inputs.
    if args.output_dir == "":
        args.output_dir = args.input_dir

    # load the model and utils
    model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda, use_onnx=args.use_onnx)
    preprocess_audios()
|
||||
@@ -0,0 +1,90 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
from argparse import RawTextHelpFormatter
|
||||
from multiprocessing import Pool
|
||||
from shutil import copytree
|
||||
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def resample_file(func_args):
    """Resample one audio file in place.

    Args:
        func_args: Pair ``(filename, output_sr)``, packed into one argument so the
            function can be mapped over a pool.
    """
    path, target_sr = func_args
    samples, loaded_sr = librosa.load(path, sr=target_sr)
    # The file is overwritten with the audio as loaded at the requested rate.
    sf.write(path, samples, loaded_sr)
|
||||
|
||||
|
||||
def resample_files(input_dir, output_sr, output_dir=None, file_ext="wav", n_jobs=10):
    """Resample every matching audio file below a folder using a process pool.

    Args:
        input_dir: Folder searched recursively for audio files.
        output_sr: Target sample rate.
        output_dir: When given, ``input_dir`` is first copied there and the copy is
            resampled; otherwise the files are changed in place.
        file_ext: Extension of the files to resample.
        n_jobs: Number of worker processes.
    """
    if output_dir:
        print("Recursively copying the input folder...")
        copytree(input_dir, output_dir)
        input_dir = output_dir

    print("Resampling the audio files...")
    pattern = os.path.join(input_dir, f"**/*.{file_ext}")
    found = glob.glob(pattern, recursive=True)
    print(f"Found {len(found)} files...")

    # One (path, sample_rate) job per file, the shape resample_file unpacks.
    jobs = [(path, output_sr) for path in found]
    with Pool(processes=n_jobs) as pool, tqdm(total=len(jobs)) as progress:
        for _ in pool.imap_unordered(resample_file, jobs):
            progress.update()

    print("Done !")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run the resampling.
    parser = argparse.ArgumentParser(
        description="""Resample a folder recusively with librosa
                       Can be used in place or create a copy of the folder as an output.\n\n
                       Example run:
                            python TTS/bin/resample.py
                                --input_dir /root/LJSpeech-1.1/
                                --output_sr 22050
                                --output_dir /root/resampled_LJSpeech-1.1/
                                --file_ext wav
                                --n_jobs 24
                    """,
        formatter_class=RawTextHelpFormatter,
    )

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = [
        (
            "--input_dir",
            dict(
                type=str,
                default=None,
                required=True,
                help="Path of the folder containing the audio files to resample",
            ),
        ),
        (
            "--output_sr",
            dict(
                type=int,
                default=22050,
                required=False,
                help="Samlple rate to which the audio files should be resampled",
            ),
        ),
        (
            "--output_dir",
            dict(
                type=str,
                default=None,
                required=False,
                help="Path of the destination folder. If not defined, the operation is done in place",
            ),
        ),
        (
            "--file_ext",
            dict(
                type=str,
                default="wav",
                required=False,
                help="Extension of the audio files to resample",
            ),
        ),
        (
            "--n_jobs",
            dict(type=int, default=None, help="Number of threads to use, by default it uses all cores"),
        ),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    resample_files(args.input_dir, args.output_sr, args.output_dir, args.file_ext, args.n_jobs)
|
||||
@@ -0,0 +1,494 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import sys
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
# pylint: disable=redefined-outer-name, unused-argument
|
||||
from pathlib import Path
|
||||
|
||||
description = """
|
||||
Synthesize speech on command line.
|
||||
|
||||
You can either use your trained model or choose a model from the provided list.
|
||||
|
||||
If you don't specify any models, then it uses LJSpeech based English model.
|
||||
|
||||
#### Single Speaker Models
|
||||
|
||||
- List provided models:
|
||||
|
||||
```
|
||||
$ tts --list_models
|
||||
```
|
||||
|
||||
- Get model info (for both tts_models and vocoder_models):
|
||||
|
||||
- Query by type/name:
|
||||
The model_info_by_name uses the name as it appears in the --list_models output.
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
For example:
|
||||
```
|
||||
$ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
|
||||
$ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
|
||||
```
|
||||
- Query by type/idx:
|
||||
The model_query_idx uses the corresponding idx from --list_models.
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --model_info_by_idx tts_models/3
|
||||
```
|
||||
|
||||
- Query model info by full name:
|
||||
```
|
||||
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
|
||||
```
|
||||
|
||||
- Run TTS with default models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run TTS and pipe out the generated TTS wav file data:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay
|
||||
```
|
||||
|
||||
- Run a TTS model with its default vocoder model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run with specific TTS and vocoder models from the list:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS model (Using Griffin-Lim Vocoder):
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
```
|
||||
|
||||
- Run your own TTS and Vocoder models:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
|
||||
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
|
||||
```
|
||||
|
||||
#### Multi-speaker Models
|
||||
|
||||
- List the available speakers and choose a <speaker_id> among them:
|
||||
|
||||
```
|
||||
$ tts --model_name "<model_type>/<language>/<dataset>/<model_name>" --list_speaker_idxs
|
||||
```
|
||||
|
||||
- Run the multi-speaker TTS model with the target speaker ID:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<model_type>/<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
- Run your own multi-speaker TTS model:
|
||||
|
||||
```
|
||||
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
|
||||
```
|
||||
|
||||
### Voice Conversion Models
|
||||
|
||||
```
|
||||
$ tts --out_path output/path/speech.wav --model_name "<model_type>/<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def str2bool(v):
    """Convert a command-line style truthy/falsy value to ``bool``.

    Meant to be used as an ``argparse`` ``type=`` callable.

    Args:
        v: Either a ``bool`` (returned unchanged) or one of the strings
            "yes"/"no", "true"/"false", "t"/"f", "y"/"n", "1"/"0".
            Matching is case-insensitive and ignores surrounding whitespace.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean value.
    """
    if isinstance(v, bool):
        return v
    # Anything that is not a string has no ``.lower()``; report it as a bad
    # argument value instead of leaking an AttributeError.
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    token = v.strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
|
||||
def main():
    """Entry point of the synthesis command-line tool.

    Parses the command-line arguments, resolves model/vocoder/encoder paths
    (downloading released models through ``ModelManager`` when only a model name
    is given), builds a ``Synthesizer`` and then either lists models/speakers/
    languages or synthesizes (or converts) audio and saves it to ``--out_path``.

    When ``--pipe_out`` is set, regular ``print`` output is suppressed and the
    original ``sys.stdout`` is handed to ``Synthesizer.save_wav`` as ``pipe_out``.
    """
    parser = argparse.ArgumentParser(
        description=description.replace(" ```\n", ""),
        formatter_class=RawTextHelpFormatter,
    )

    parser.add_argument(
        "--list_models",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="list available pre-trained TTS and vocoder models.",
    )

    parser.add_argument(
        "--model_info_by_idx",
        type=str,
        default=None,
        help="model info using query format: <model_type>/<model_query_idx>",
    )

    parser.add_argument(
        "--model_info_by_name",
        type=str,
        default=None,
        help="model info using query format: <model_type>/<language>/<dataset>/<model_name>",
    )

    parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")

    # Args for running pre-trained TTS models.
    parser.add_argument(
        "--model_name",
        type=str,
        default="tts_models/en/ljspeech/tacotron2-DDC",
        help="Name of one of the pre-trained TTS models in format <language>/<dataset>/<model_name>",
    )
    parser.add_argument(
        "--vocoder_name",
        type=str,
        default=None,
        help="Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>",
    )

    # Args for running custom models
    parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.")
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Path to model file.",
    )
    parser.add_argument(
        "--out_path",
        type=str,
        default="tts_output.wav",
        help="Output wav file path.",
    )
    # NOTE: `type=bool` would turn every non-empty value (even "False") into True,
    # so boolean flags go through `str2bool` like the other switches of this CLI.
    parser.add_argument(
        "--use_cuda",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Run model on CUDA.",
    )
    parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu")
    parser.add_argument(
        "--vocoder_path",
        type=str,
        help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).",
        default=None,
    )
    parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None)
    parser.add_argument(
        "--encoder_path",
        type=str,
        help="Path to speaker encoder model file.",
        default=None,
    )
    parser.add_argument("--encoder_config_path", type=str, help="Path to speaker encoder config file.", default=None)
    parser.add_argument(
        "--pipe_out",
        help="stdout the generated TTS wav file for shell pipe.",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
    )

    # args for multi-speaker synthesis
    parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
    parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
    parser.add_argument(
        "--speaker_idx",
        type=str,
        help="Target speaker ID for a multi-speaker TTS model.",
        default=None,
    )
    parser.add_argument(
        "--language_idx",
        type=str,
        help="Target language ID for a multi-lingual TTS model.",
        default=None,
    )
    parser.add_argument(
        "--speaker_wav",
        nargs="+",
        help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.",
        default=None,
    )
    parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
    parser.add_argument(
        "--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None
    )
    parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None)
    parser.add_argument(
        "--list_speaker_idxs",
        help="List available speaker ids for the defined multi-speaker model.",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
    )
    parser.add_argument(
        "--list_language_idxs",
        help="List available language ids for the defined multi-lingual model.",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
    )
    # aux args
    parser.add_argument(
        "--save_spectogram",
        type=str2bool,  # see the note on --use_cuda: `type=bool` mis-parses "False"
        nargs="?",
        const=True,
        help="If true save raw spectogram for further (vocoder) processing in out_path.",
        default=False,
    )
    parser.add_argument(
        "--reference_wav",
        type=str,
        help="Reference wav file to convert in the voice of the speaker_idx or speaker_wav",
        default=None,
    )
    parser.add_argument(
        "--reference_speaker_idx",
        type=str,
        help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).",
        default=None,
    )
    parser.add_argument(
        "--progress_bar",
        type=str2bool,
        help="If true shows a progress bar for the model download. Defaults to True",
        default=True,
    )

    # voice conversion args
    parser.add_argument(
        "--source_wav",
        type=str,
        default=None,
        help="Original audio file to convert in the voice of the target_wav",
    )
    parser.add_argument(
        "--target_wav",
        type=str,
        default=None,
        help="Target audio file to convert in the voice of the source_wav",
    )

    parser.add_argument(
        "--voice_dir",
        type=str,
        default=None,
        help="Voice dir for tortoise model",
    )

    args = parser.parse_args()

    # Print the help text (and exit) when none of the action-selecting arguments is set.
    check_args = [
        args.text,
        args.list_models,
        args.list_speaker_idxs,
        args.list_language_idxs,
        args.reference_wav,
        args.model_info_by_idx,
        args.model_info_by_name,
        args.source_wav,
        args.target_wav,
    ]
    if not any(check_args):
        parser.parse_args(["-h"])

    # Keep a handle on the real stdout before it is (optionally) silenced below.
    pipe_out = sys.stdout if args.pipe_out else None

    with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
        # Late-import to make things load faster
        from TTS.api import TTS
        from TTS.utils.manage import ModelManager
        from TTS.utils.synthesizer import Synthesizer

        # load model manager
        path = Path(__file__).parent / "../.models.json"
        manager = ModelManager(path, progress_bar=args.progress_bar)
        api = TTS()

        tts_path = None
        tts_config_path = None
        speakers_file_path = None
        language_ids_file_path = None
        vocoder_path = None
        vocoder_config_path = None
        encoder_path = None
        encoder_config_path = None
        vc_path = None
        vc_config_path = None
        model_dir = None

        # CASE1 #list : list pre-trained TTS models
        if args.list_models:
            manager.list_models()
            sys.exit()

        # CASE2 #info : model info for pre-trained TTS models
        if args.model_info_by_idx:
            model_query = args.model_info_by_idx
            manager.model_info_by_idx(model_query)
            sys.exit()

        if args.model_info_by_name:
            model_query_full_name = args.model_info_by_name
            manager.model_info_by_full_name(model_query_full_name)
            sys.exit()

        # CASE3: load pre-trained model paths
        if args.model_name is not None and not args.model_path:
            model_path, config_path, model_item = manager.download_model(args.model_name)
            # tts model
            if model_item["model_type"] == "tts_models":
                tts_path = model_path
                tts_config_path = config_path
                if "default_vocoder" in model_item:
                    # an explicitly requested vocoder wins over the model's default one
                    args.vocoder_name = (
                        model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
                    )

            # voice conversion model
            if model_item["model_type"] == "voice_conversion_models":
                vc_path = model_path
                vc_config_path = config_path

            # tts model with multiple files to be loaded from the directory path
            if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
                model_dir = model_path
                tts_path = None
                tts_config_path = None
                args.vocoder_name = None

        # load vocoder
        if args.vocoder_name is not None and not args.vocoder_path:
            vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

        # CASE4: set custom model paths
        if args.model_path is not None:
            tts_path = args.model_path
            tts_config_path = args.config_path
            speakers_file_path = args.speakers_file_path
            language_ids_file_path = args.language_ids_file_path

        if args.vocoder_path is not None:
            vocoder_path = args.vocoder_path
            vocoder_config_path = args.vocoder_config_path

        if args.encoder_path is not None:
            encoder_path = args.encoder_path
            encoder_config_path = args.encoder_config_path

        # `--use_cuda` overrides whatever `--device` says
        device = args.device
        if args.use_cuda:
            device = "cuda"

        # load models
        synthesizer = Synthesizer(
            tts_path,
            tts_config_path,
            speakers_file_path,
            language_ids_file_path,
            vocoder_path,
            vocoder_config_path,
            encoder_path,
            encoder_config_path,
            vc_path,
            vc_config_path,
            model_dir,
            args.voice_dir,
        ).to(device)

        # query speaker ids of a multi-speaker model.
        if args.list_speaker_idxs:
            print(
                " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
            )
            print(synthesizer.tts_model.speaker_manager.name_to_id)
            return

        # query language ids of a multi-lingual model.
        if args.list_language_idxs:
            print(
                " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
            )
            print(synthesizer.tts_model.language_manager.name_to_id)
            return

        # check the arguments against a multi-speaker model.
        if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
            print(
                " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
                "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
            )
            return

        # RUN THE SYNTHESIS
        if args.text:
            print(" > Text: {}".format(args.text))

        # kick it
        if tts_path is not None:
            wav = synthesizer.tts(
                args.text,
                speaker_name=args.speaker_idx,
                language_name=args.language_idx,
                speaker_wav=args.speaker_wav,
                reference_wav=args.reference_wav,
                style_wav=args.capacitron_style_wav,
                style_text=args.capacitron_style_text,
                reference_speaker_name=args.reference_speaker_idx,
            )
        elif vc_path is not None:
            wav = synthesizer.voice_conversion(
                source_wav=args.source_wav,
                target_wav=args.target_wav,
            )
        elif model_dir is not None:
            wav = synthesizer.tts(
                args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
            )

        # save the results
        print(" > Saving output to {}".format(args.out_path))
        synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,332 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from trainer.io import copy_model_files, save_best_model, save_checkpoint
|
||||
from trainer.torch import NoamLR
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
|
||||
from TTS.encoder.dataset import EncoderDataset
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.encoder.utils.training import init_training
|
||||
from TTS.encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
||||
from TTS.utils.samplers import PerfectBatchSampler
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
# Process-wide torch setup: fixed RNG seed, cuDNN enabled with autotuning.
torch.manual_seed(54321)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Detected hardware; `use_cuda` is read by the training/evaluation functions below.
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
|
||||
|
||||
|
||||
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
    """Create the ``DataLoader`` used for encoder training or evaluation.

    Reads the module-level config ``c`` and the globals ``meta_data_train`` /
    ``meta_data_eval``; for evaluation it additionally reads ``train_classes``.

    Args:
        ap (AudioProcessor): Audio processor handed to ``EncoderDataset``.
        is_val (bool): Build the evaluation loader instead of the training one.
        verbose (bool): Forwarded to ``EncoderDataset``.

    Returns:
        tuple: ``(loader, classes, dataset.get_map_classid_to_classname())``.

    Raises:
        RuntimeError: If the dataset has fewer classes than the configured
            number of classes per batch.
    """
    # Pick the split-specific settings once.
    if is_val:
        utters_per_class = c.eval_num_utter_per_class
        classes_per_batch = c.eval_num_classes_in_batch
        samples = meta_data_eval
        augmentation = None  # no augmentation on the evaluation split
    else:
        utters_per_class = c.num_utter_per_class
        classes_per_batch = c.num_classes_in_batch
        samples = meta_data_train
        augmentation = c.audio_augmentation

    dataset = EncoderDataset(
        c,
        ap,
        samples,
        voice_len=c.voice_len,
        num_utter_per_class=utters_per_class,
        num_classes_in_batch=classes_per_batch,
        verbose=verbose,
        augmentation_config=augmentation,
        use_torch_spec=c.model_params.get("use_torch_spec", False),
    )
    classes = dataset.get_class_list()

    sampler = PerfectBatchSampler(
        dataset.items,
        classes,
        batch_size=classes_per_batch * utters_per_class,  # total batch size
        num_classes_in_batch=classes_per_batch,
        num_gpus=1,
        shuffle=not is_val,
        drop_last=True,
    )

    if len(classes) < classes_per_batch:
        if is_val:
            raise RuntimeError(
                f"config.eval_num_classes_in_batch ({classes_per_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !"
            )
        raise RuntimeError(
            f"config.num_classes_in_batch ({classes_per_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !"
        )

    # Use the training class list for evaluation too, so class ids stay aligned
    # when the two splits do not contain the same number of classes.
    if is_val:
        dataset.set_classes(train_classes)

    loader = DataLoader(
        dataset,
        num_workers=c.num_loader_workers,
        batch_sampler=sampler,
        collate_fn=dataset.collate_fn,
    )
    return loader, classes, dataset.get_map_classid_to_classname()
|
||||
|
||||
|
||||
def evaluation(model, criterion, data_loader, global_step):
    """Run one pass over the evaluation loader and log the results.

    Uses the module-level config ``c``, ``use_cuda`` and ``dashboard_logger``.

    Args:
        model: Encoder model; called as ``model(inputs)``.
        criterion: Loss callable taking the reshaped outputs and the labels.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        global_step (int): Step index used when logging stats and figures.

    Returns:
        float: Sum of the per-batch losses divided by ``len(data_loader)``.
    """
    eval_loss = 0
    for inputs, labels in data_loader:
        with torch.no_grad():
            # Regroup the batch so samples of the same class are contiguous:
            # the sampler yields e.g. [3,2,1,3,2,1] while the loss expects [3,3,2,2,1,1].
            labels = torch.transpose(
                labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1
            ).reshape(labels.shape)
            inputs = torch.transpose(
                inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1
            ).reshape(inputs.shape)

            # dispatch data to GPU
            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            # forward pass model
            outputs = model(inputs)

            # loss computation
            loss = criterion(
                outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels
            )

            eval_loss += loss.item()

    eval_avg_loss = eval_loss / len(data_loader)
    # save stats
    dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
    # Plot the last evaluation batch. The batch was built with the *eval* class
    # count, so that value (not the training one) must describe its layout.
    figures = {
        "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.eval_num_classes_in_batch),
    }
    dashboard_logger.eval_figures(global_step, figures)
    return eval_avg_loss
|
||||
|
||||
|
||||
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
    """Train the encoder for ``c.epochs`` epochs.

    Uses the module-level ``c``, ``use_cuda``, ``dashboard_logger`` and ``OUT_PATH``.
    Per step it logs stats/figures every ``c.steps_plot_stats`` steps, prints every
    ``c.print_step`` steps and saves a checkpoint every ``c.save_step`` steps. After
    each epoch it optionally evaluates (``c.run_eval``) and saves the best model.

    Args:
        model: Encoder model to optimise.
        optimizer: Optimizer stepping the model parameters.
        scheduler: LR scheduler; only stepped when ``c.lr_decay`` is set.
        criterion: Loss callable; its ``state_dict()`` is stored in checkpoints.
        data_loader: Training loader yielding ``(inputs, labels)`` batches.
        eval_data_loader: Loader passed to ``evaluation``.
        global_step (int): Step counter to continue from.

    Returns:
        tuple: ``(best_loss, global_step)``.
    """
    model.train()
    best_loss = {"train_loss": None, "eval_loss": float("inf")}
    avg_loader_time = 0
    end_time = time.time()
    for epoch in range(c.epochs):
        tot_loss = 0
        epoch_time = 0
        for inputs, labels in data_loader:
            start_time = time.time()

            # Regroup the batch so samples of the same class are contiguous:
            # the sampler yields e.g. [3,2,1,3,2,1] while the loss expects [3,3,2,2,1,1].
            # TODO: cover this regrouping with a unit test.
            labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(
                labels.shape
            )
            inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(
                inputs.shape
            )

            # time spent waiting for this batch since the previous step finished
            loader_time = time.time() - end_time
            global_step += 1

            # setup lr
            if c.lr_decay:
                scheduler.step()
            optimizer.zero_grad()

            # dispatch data to GPU
            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            # forward pass model
            outputs = model(inputs)

            # loss computation
            loss = criterion(
                outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels
            )
            loss.backward()
            grad_norm, _ = check_update(model, c.grad_clip)
            optimizer.step()

            step_time = time.time() - start_time
            epoch_time += step_time

            # accumulate the total epoch loss
            step_loss = loss.item()
            tot_loss += step_loss

            # Running average of the loader time, weighted by the worker count.
            workers = max(c.num_loader_workers, 1)
            if avg_loader_time != 0:
                avg_loader_time = 1 / workers * loader_time + (workers - 1) / workers * avg_loader_time
            else:
                avg_loader_time = loader_time
            current_lr = optimizer.param_groups[0]["lr"]

            if global_step % c.steps_plot_stats == 0:
                # Plot Training Epoch Stats
                train_stats = {
                    "loss": step_loss,
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time,
                    "avg_loader_time": avg_loader_time,
                }
                dashboard_logger.train_epoch_stats(global_step, train_stats)
                figures = {
                    "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
                }
                dashboard_logger.train_figures(global_step, figures)

            if global_step % c.print_step == 0:
                print(
                    " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
                    "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
                        global_step, step_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr
                    ),
                    flush=True,
                )

            if global_step % c.save_step == 0:
                # save model
                save_checkpoint(
                    c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
                )

            end_time = time.time()

        print("")
        print(
            ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
            "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
                epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time
            ),
            flush=True,
        )
        # evaluation
        if c.run_eval:
            model.eval()
            eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
            print("\n\n")
            print("--> EVAL PERFORMANCE")
            print(
                " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
                flush=True,
            )
            # save the best checkpoint
            best_loss = save_best_model(
                {"train_loss": None, "eval_loss": eval_loss},
                best_loss,
                c,
                model,
                optimizer,
                None,
                global_step,
                epoch,
                OUT_PATH,
                criterion=criterion.state_dict(),
            )
            # back to training mode for the next epoch
            model.train()

    return best_loss, global_step
|
||||
|
||||
|
||||
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, model, optimizer and criterion, then run encoder training.

    Publishes ``meta_data_train``, ``meta_data_eval`` and ``train_classes`` as
    module globals because ``setup_loader`` reads them. Uses the module-level
    ``c``, ``OUT_PATH`` and ``use_cuda``.

    Args:
        args: Parsed training arguments; ``restore_path`` is read and
            ``restore_step`` is (re)assigned here.
    """
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, train_classes

    ap = AudioProcessor(**c.audio)
    model = setup_encoder_model(c)
    optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)

    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)

    train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
    eval_data_loader = None
    if c.run_eval:
        eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)

    criterion = model.get_criterion(c, len(train_classes))

    if c.loss == "softmaxproto" and c.model != "speaker_encoder":
        # store the class-id mapping in the config before it is copied to the output folder
        c.map_classid_to_classname = map_classid_to_classname
        copy_model_files(c, OUT_PATH, new_fields={})

    if not args.restore_path:
        args.restore_step = 0
    else:
        criterion, args.restore_step = model.load_checkpoint(
            c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
        )
        print(" > Model restored from step %d" % args.restore_step, flush=True)

    scheduler = None
    if c.lr_decay:
        scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if use_cuda:
        model = model.cuda()
        criterion.cuda()

    global_step = args.restore_step
    _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # These names are module globals that the functions above read.
    args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()

    try:
        main(args)
    except KeyboardInterrupt:
        # Interrupted by the user: clean up the experiment folder and leave with status 0.
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            # force the exit even though SystemExit was raised inside a handler
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        # Any other failure: clean up, show the traceback and exit with status 1.
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
|
||||
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models import setup_model
|
||||
|
||||
|
||||
@dataclass
class TrainTTSArgs(TrainerArgs):
    """``TrainerArgs`` extended with the path of the model config file."""

    config_path: str = field(
        default=None,
        metadata={"help": "Path to the config file."},
    )
|
||||
|
||||
|
||||
def main():
    """Run `tts` model training directly by a `config.json` file."""
    # Build the trainer arguments and let the command line override them;
    # arguments the parser does not know are kept as config overrides.
    train_args = TrainTTSArgs()
    parser = train_args.init_argparse(arg_prefix="")
    args, config_overrides = parser.parse_known_args()
    train_args.parse_args(args)

    if args.config_path or args.continue_path:
        # An explicit config file takes precedence; otherwise reuse the
        # config.json of the experiment that is being continued.
        if args.config_path:
            config_file = args.config_path
        else:
            config_file = os.path.join(args.continue_path, "config.json")
        config = load_config(config_file)
        if len(config_overrides) > 0:
            config.parse_known_args(config_overrides, relaxed_parser=True)
    else:
        # No file given: build the config from the console arguments alone.
        from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

        config_base = BaseTrainingConfig()
        config_base.parse_known_args(config_overrides)
        config = register_config(config_base.model)()

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        config.datasets,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # init the model from config
    model = setup_model(config, train_samples + eval_samples)

    # init the trainer and start training
    trainer = Trainer(
        train_args,
        model.config,
        config.output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        parse_command_line_args=False,
    )
    trainer.fit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.models import setup_model
|
||||
|
||||
|
||||
@dataclass
class TrainVocoderArgs(TrainerArgs):
    """``TrainerArgs`` extended with the path of the model config file."""

    config_path: str = field(
        default=None,
        metadata={"help": "Path to the config file."},
    )
|
||||
|
||||
|
||||
def main():
    """Run vocoder model training directly by a `config.json` file."""
    # Build the trainer arguments and let the command line override them;
    # arguments the parser does not know are kept as config overrides.
    train_args = TrainVocoderArgs()
    parser = train_args.init_argparse(arg_prefix="")
    args, config_overrides = parser.parse_known_args()
    train_args.parse_args(args)

    if args.config_path or args.continue_path:
        # An explicit config file takes precedence; otherwise reuse the
        # config.json of the experiment that is being continued.
        if args.config_path:
            config_file = args.config_path
        else:
            config_file = os.path.join(args.continue_path, "config.json")
        config = load_config(config_file)
        if len(config_overrides) > 0:
            config.parse_known_args(config_overrides, relaxed_parser=True)
    else:
        # No file given: build the config from the console arguments alone.
        from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

        config_base = BaseTrainingConfig()
        config_base.parse_known_args(config_overrides)
        config = register_config(config_base.model)()

    # load training samples (both loaders return the eval split first)
    if "feature_path" in config and config.feature_path:
        # load pre-computed features
        print(f" > Loading features from: {config.feature_path}")
        eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size)
    else:
        # load data raw wav files
        eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # init the model from config
    model = setup_model(config)

    # init the trainer and start training
    trainer = Trainer(
        train_args,
        config,
        config.output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        training_assets={"audio_processor": ap},
        parse_command_line_args=False,
    )
    trainer.fit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,103 @@
|
||||
"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
|
||||
import argparse
|
||||
from itertools import product as cartesian_product
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||
from TTS.vocoder.models import setup_model
|
||||
|
||||
if __name__ == "__main__":
    # Grid search over per-iteration noise levels ("beta") for a WaveGrad model:
    # every candidate schedule is scored by the squared error between the input
    # mel-spectrogram and the mel-spectrogram of the generated audio.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
    parser.add_argument("--config_path", type=str, help="Path to model config file.")
    parser.add_argument("--data_path", type=str, help="Path to data directory.")
    parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
    parser.add_argument(
        "--num_iter",
        type=int,
        help="Number of model inference iterations that you like to optimize noise schedule for.",
    )
    parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
    parser.add_argument(
        "--search_depth",
        type=int,
        default=3,
        help="Search granularity. Increasing this increases the run-time exponentially.",
    )

    # load config
    args = parser.parse_args()
    config = load_config(args.config_path)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # load dataset (only the first `num_samples` items are used)
    _, train_data = load_wav_data(args.data_path, 0)
    train_data = train_data[: args.num_samples]
    dataset = WaveGradDataset(
        ap=ap,
        items=train_data,
        seq_len=-1,
        hop_len=ap.hop_length,
        pad_short=config.pad_short,
        conv_pad=config.conv_pad,
        is_training=True,
        return_segments=False,
        use_noise_augment=False,
        use_cache=False,
        verbose=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collate_full_clips,
        drop_last=False,
        num_workers=config.num_loader_workers,
        pin_memory=False,
    )

    # setup the model
    # NOTE(review): `--model_path` is parsed but no checkpoint is loaded anywhere in
    # this script, so the search appears to run on freshly initialised weights -
    # confirm whether a checkpoint load is missing here.
    model = setup_model(config)
    if args.use_cuda:
        model.cuda()

    # setup optimization parameters
    base_values = sorted(10 * np.random.uniform(size=args.search_depth))
    print(f" > base values: {base_values}")
    exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
    best_error = float("inf")
    best_schedule = None  # pylint: disable=C0103
    total_search_iter = len(base_values) ** args.num_iter
    candidates = cartesian_product(base_values, repeat=args.num_iter)
    for base in tqdm(candidates, total=total_search_iter):
        # one multiplier per inference iteration, applied to the log-spaced exponents
        beta = exponents * base
        model.compute_noise_level(beta)
        for mel, _ in loader:
            y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
            if args.use_cuda:
                y_hat = y_hat.cpu()
            y_hat = y_hat.numpy()

            # mel-spectrogram of every generated waveform, dropping the last frame
            mel_hat = torch.stack(
                [torch.from_numpy(ap.melspectrogram(y_hat[i, 0])[:, :-1]) for i in range(y_hat.shape[0])]
            )
            mse = torch.sum((mel - mel_hat) ** 2).mean()
            error = mse.item()
            if error < best_error:
                best_error = error
                best_schedule = {"beta": beta}
                print(f" > Found a better schedule. - MSE: {error}")
                # persist immediately so an interrupted search keeps its best result
                np.save(args.output_path, best_schedule)
|
||||
@@ -0,0 +1,135 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict
|
||||
|
||||
import fsspec
|
||||
import yaml
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config.shared_configs import *
|
||||
from TTS.utils.generic_utils import find_module
|
||||
|
||||
|
||||
def _strip_json_comments(text):
    """Remove `//` line comments and `/* ... */` block comments from JSON text.

    String literals are matched first and copied through unchanged, so comment
    markers that appear inside a string (e.g. the `//` of a URL) are preserved.

    Args:
        text (str): JSON source that may contain comments.

    Returns:
        str: the same text with every comment removed.
    """
    # group 1: string literal | group 2: block comment (may span lines) | group 3: line comment
    pattern = r"(\"(?:[^\"\\]|\\.)*\")|(/\*[\s\S]*?\*/)|(//.*)"
    # Only string literals are re-emitted; both comment forms are replaced by nothing.
    return re.sub(pattern, lambda m: m.group(1) or "", text)


def read_json_with_comments(json_path):
    """Load a JSON file that may contain comments. Kept for backward compat.

    Args:
        json_path (str): path of the file, opened through `fsspec`.

    Returns:
        The object decoded by `json.loads` after the comments are stripped.
    """
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    # handle comments but not urls with //
    return json.loads(_strip_json_comments(input_str))
|
||||
|
||||
def register_config(model_name: str) -> Coqpit:
    """Look up the config class that belongs to a model name.

    The class is searched under the module name `<model_name>_config` in each of
    the known config packages. Every package is tried, so when more than one of
    them provides a match the last one wins.

    Args:
        model_name (str): Model name.

    Raises:
        ModuleNotFoundError: No matching config for the model name.

    Returns:
        Coqpit: config class.
    """
    module_name = model_name + "_config"
    found = None

    # TODO: fix this
    if model_name == "xtts":
        from TTS.tts.configs.xtts_config import XttsConfig

        found = XttsConfig

    for package in ("TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"):
        try:
            found = find_module(package, module_name)
        except ModuleNotFoundError:
            # this package has no such config; keep whatever was found so far
            continue

    if found is None:
        raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
    return found
|
||||
|
||||
|
||||
def _process_model_name(config_dict: Dict) -> str:
    """Normalize the model name stored in a config dictionary.

    This is a band-aid for the old `vocoder` model names: the name is read from
    the `model` field, or from `generator_model` when `model` is absent, and any
    `_generator` / `_discriminator` part is removed from it.

    Args:
        config_dict (Dict): A dictionary including the config fields.

    Returns:
        str: Formatted modelname.
    """
    if "model" in config_dict:
        name = config_dict["model"]
    else:
        name = config_dict["generator_model"]
    for role in ("_generator", "_discriminator"):
        name = name.replace(role, "")
    return name
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Coqpit:
    """Read a `json` or `yaml` file and build the matching TTS config object.

    The file is first loaded as a `dict`. Its model name selects the Config
    class, which is then instantiated and populated from that `dict`.

    Args:
        config_path (str): path to the config file.

    Raises:
        TypeError: given config file has an unknown type.

    Returns:
        Coqpit: TTS config object.
    """
    ext = os.path.splitext(config_path)[1]
    if ext == ".json":
        try:
            with fsspec.open(config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.decoder.JSONDecodeError:
            # backwards compat.
            data = read_json_with_comments(config_path)
    elif ext in (".yml", ".yaml"):
        with fsspec.open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    else:
        raise TypeError(f" [!] Unknown config file type {ext}")

    config_dict = dict(data)
    config_class = register_config(_process_model_name(config_dict).lower())
    config = config_class()
    config.from_dict(config_dict)
    return config
|
||||
|
||||
|
||||
def check_config_and_model_args(config, arg_name, value):
    """Tell whether an argument of the config equals the given value.

    The argument is looked up in `config.model_args` first and then in `config`
    itself. When it exists in neither place the result is False.
    This is to patch up the compatibility between models with and without `model_args`.

    TODO: Remove this in the future with a unified approach.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name] == value
    if not hasattr(config, arg_name):
        return False
    return config[arg_name] == value
|
||||
|
||||
|
||||
def get_from_config_or_model_args(config, arg_name):
    """Return the argument from `config.model_args` when present there, else from `config`."""
    in_model_args = hasattr(config, "model_args") and arg_name in config.model_args
    return config.model_args[arg_name] if in_model_args else config[arg_name]
|
||||
|
||||
|
||||
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
    """Return the argument from `config.model_args` or `config`, else `def_val`.

    `config.model_args` is consulted first, then `config` itself. `def_val` is
    returned when the argument exists in neither place.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name]
    return config[arg_name] if hasattr(config, arg_name) else def_val
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,268 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List
|
||||
|
||||
from coqpit import Coqpit, check_argument
|
||||
from trainer import TrainerConfig
|
||||
|
||||
|
||||
@dataclass
class BaseAudioConfig(Coqpit):
    """Base config to define audio processing parameters. It is used to initialize
    ```TTS.utils.audio.AudioProcessor.```

    Args:
        fft_size (int):
            Number of STFT frequency levels aka. size of the linear spectrogram frame. Defaults to 1024.

        win_length (int):
            Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
            ```fft_size```. Defaults to 1024.

        hop_length (int):
            Number of audio samples between adjacent STFT columns. Defaults to 256.

        frame_shift_ms (int):
            Set ```hop_length``` based on milliseconds and sampling rate. Defaults to None.

        frame_length_ms (int):
            Set ```win_length``` based on milliseconds and sampling rate. Defaults to None.

        stft_pad_mode (str):
            Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.

        sample_rate (int):
            Audio sampling rate. Defaults to 22050.

        resample (bool):
            Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.

        preemphasis (float):
            Preemphasis coefficient. Defaults to 0.0.

        ref_level_db (int):
            Reference dB level to rebase the audio signal and ignore the level below. 20 dB is assumed the sound of
            air. Defaults to 20.

        do_sound_norm (bool):
            Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.

        log_func (str):
            Numpy log function used for amplitude to dB conversion. Defaults to 'np.log10'.

        do_trim_silence (bool):
            Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.

        trim_db (int):
            Silence threshold used for silence trimming. Defaults to 45.

        do_rms_norm (bool, optional):
            enable/disable RMS volume normalization when loading an audio file. Defaults to False.

        db_level (float, optional):
            dB level used for rms normalization. The range is -99 to 0. Defaults to None.

        power (float):
            Exponent used for expanding spectrogram levels before running Griffin Lim. It helps to reduce the
            artifacts in the synthesized voice. Defaults to 1.5.

        griffin_lim_iters (int):
            Number of Griffin Lim iterations. Defaults to 60.

        num_mels (int):
            Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80.

        mel_fmin (float):
            Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
            It needs to be adjusted for a dataset. Defaults to 0.

        mel_fmax (float):
            Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
            Defaults to None.

        spec_gain (int):
            Gain applied when converting amplitude to dB. Defaults to 20.

        do_amp_to_db_linear (bool, optional):
            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.

        do_amp_to_db_mel (bool, optional):
            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.

        pitch_fmax (float, optional):
            Maximum frequency of the F0 frames. Defaults to ```640```.

        pitch_fmin (float, optional):
            Minimum frequency of the F0 frames. Defaults to ```1```.

        signal_norm (bool):
            enable/disable signal normalization. Defaults to True.

        min_level_db (int):
            minimum dB threshold for the computed melspectrograms. Defaults to -100.

        symmetric_norm (bool):
            enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else
            [0, k]. Defaults to True.

        max_norm (float):
            ```k``` defining the normalization range. Defaults to 4.0.

        clip_norm (bool):
            enable/disable clipping the out of range values in the normalized audio signal. Defaults to True.

        stats_path (str):
            Path to the computed stats file. Defaults to None.
    """

    # stft parameters
    fft_size: int = 1024
    win_length: int = 1024
    hop_length: int = 256
    frame_shift_ms: int = None
    frame_length_ms: int = None
    stft_pad_mode: str = "reflect"
    # audio processing parameters
    sample_rate: int = 22050
    resample: bool = False
    preemphasis: float = 0.0
    ref_level_db: int = 20
    do_sound_norm: bool = False
    log_func: str = "np.log10"
    # silence trimming
    do_trim_silence: bool = True
    trim_db: int = 45
    # rms volume normalization
    do_rms_norm: bool = False
    db_level: float = None
    # griffin-lim params
    power: float = 1.5
    griffin_lim_iters: int = 60
    # mel-spec params
    num_mels: int = 80
    mel_fmin: float = 0.0
    mel_fmax: float = None
    spec_gain: int = 20
    do_amp_to_db_linear: bool = True
    do_amp_to_db_mel: bool = True
    # f0 params
    pitch_fmax: float = 640.0
    pitch_fmin: float = 1.0
    # normalization params
    signal_norm: bool = True
    min_level_db: int = -100
    symmetric_norm: bool = True
    max_norm: float = 4.0
    clip_norm: bool = True
    stats_path: str = None

    def check_values(self):
        """Check config fields"""
        values = asdict(self)
        # (field name, constraints forwarded to `check_argument`), listed in validation order
        checks = (
            ("num_mels", {"min_val": 10, "max_val": 2056}),
            ("fft_size", {"min_val": 128, "max_val": 4058}),
            ("sample_rate", {"min_val": 512, "max_val": 100000}),
            ("frame_length_ms", {"min_val": 10, "max_val": 1000, "alternative": "win_length"}),
            ("frame_shift_ms", {"min_val": 1, "max_val": 1000, "alternative": "hop_length"}),
            ("preemphasis", {"min_val": 0, "max_val": 1}),
            ("min_level_db", {"min_val": -1000, "max_val": 10}),
            ("ref_level_db", {"min_val": 0, "max_val": 1000}),
            ("power", {"min_val": 1, "max_val": 5}),
            ("griffin_lim_iters", {"min_val": 10, "max_val": 1000}),
            # normalization parameters
            ("signal_norm", {}),
            ("symmetric_norm", {}),
            ("max_norm", {"min_val": 0.1, "max_val": 1000}),
            ("clip_norm", {}),
            ("mel_fmin", {"min_val": 0.0, "max_val": 1000}),
            ("mel_fmax", {"min_val": 500.0, "allow_none": True}),
            ("spec_gain", {"min_val": 1, "max_val": 100}),
            ("do_trim_silence", {}),
            ("trim_db", {}),
        )
        for field_name, constraints in checks:
            check_argument(field_name, values, restricted=True, **constraints)
|
||||
|
||||
|
||||
@dataclass
class BaseDatasetConfig(Coqpit):
    """Base config for TTS datasets.

    Args:
        formatter (str):
            Formatter name that defines used formatter in ```TTS.tts.datasets.formatter```. Defaults to `""`.

        dataset_name (str):
            Unique name for the dataset. Defaults to `""`.

        path (str):
            Root path to the dataset files. Defaults to `""`.

        meta_file_train (str):
            Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
            Defaults to `""`.

        ignored_speakers (List):
            List of speakers IDs that are not used at the training. Defaults to None.

        language (str):
            Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`.

        phonemizer (str):
            Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`.

        meta_file_val (str):
            Name of the dataset meta file that defines the instances used at validation. Defaults to `""`.

        meta_file_attn_mask (str):
            Path to the file that lists the attention mask files used with models that require attention masks to
            train the duration predictor. Defaults to `""`.
    """

    formatter: str = ""
    dataset_name: str = ""
    path: str = ""
    meta_file_train: str = ""
    ignored_speakers: List[str] = None
    language: str = ""
    phonemizer: str = ""
    meta_file_val: str = ""
    meta_file_attn_mask: str = ""

    def check_values(self):
        """Check config fields"""
        values = asdict(self)
        # (field name, value passed as `restricted`), listed in validation order
        checks = (
            ("formatter", True),
            ("path", True),
            ("meta_file_train", True),
            ("meta_file_val", False),
            ("meta_file_attn_mask", False),
        )
        for field_name, is_restricted in checks:
            check_argument(field_name, values, restricted=is_restricted)
|
||||
|
||||
|
||||
@dataclass
class BaseTrainingConfig(TrainerConfig):
    """Base config to define the basic 🐸TTS training parameters that are shared
    among all the models. It extends ```trainer.TrainerConfig```.

    Args:
        model (str):
            Name of the model that is used in the training. Defaults to None.

        num_loader_workers (int):
            Number of workers for training time dataloader. Defaults to 0.

        num_eval_loader_workers (int):
            Number of workers for evaluation time dataloader. Defaults to 0.

        use_noise_augment (bool):
            Defaults to False. NOTE(review): presumably toggles noise augmentation in the data
            loading pipeline; the consumer is not visible here — confirm.
    """

    model: str = None
    # dataloading
    num_loader_workers: int = 0
    num_eval_loader_workers: int = 0
    use_noise_augment: bool = False
|
||||
@@ -0,0 +1,18 @@
|
||||
### Speaker Encoder
|
||||
|
||||
This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding.
|
||||
|
||||
With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart.
|
||||
|
||||
Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q).
|
||||
|
||||

|
||||
|
||||
Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page.
|
||||
|
||||
To run the code, you need to follow the same flow as in TTS.
|
||||
|
||||
- Define 'config.json' for your needs. Note that the audio parameters should match your TTS model.
|
||||
- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360```
|
||||
- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path```. This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files.
|
||||
- Watch training on Tensorboard as in TTS
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,61 @@
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Dict, List
|
||||
|
||||
from coqpit import MISSING
|
||||
|
||||
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
|
||||
|
||||
|
||||
@dataclass
class BaseEncoderConfig(BaseTrainingConfig):
    """Defines parameters for a Generic Encoder model.

    On top of the shared training fields it holds the audio and dataset
    configs, the encoder architecture parameters (`model_params`), the
    optimization and logging settings, and the batch composition
    (`num_classes_in_batch` x `num_utter_per_class`) used by the data loader.
    `num_classes_in_batch`, `num_utter_per_class` and `num_loader_workers`
    have no default and must be set by the user.
    """

    model: str = None
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # model params
    model_params: Dict = field(
        default_factory=lambda: {
            "model_name": "lstm",
            "input_dim": 80,
            "proj_dim": 256,
            "lstm_dim": 768,
            "num_lstm_layers": 3,
            "use_lstm_with_projection": True,
        }
    )

    audio_augmentation: Dict = field(default_factory=lambda: {})

    # training params
    epochs: int = 10000
    loss: str = "angleproto"
    grad_clip: float = 3.0
    lr: float = 0.0001
    optimizer: str = "radam"
    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
    lr_decay: bool = False
    warmup_steps: int = 4000

    # logging params
    tb_model_param_stats: bool = False
    steps_plot_stats: int = 10
    save_step: int = 1000
    print_step: int = 20
    run_eval: bool = False

    # data loader
    num_classes_in_batch: int = MISSING
    num_utter_per_class: int = MISSING
    eval_num_classes_in_batch: int = None
    eval_num_utter_per_class: int = None

    num_loader_workers: int = MISSING
    voice_len: float = 1.6

    def check_values(self):
        """Run the parent checks, then require `model_params["input_dim"]` to equal `audio.num_mels`."""
        super().check_values()
        input_dim = asdict(self)["model_params"]["input_dim"]
        assert (
            input_dim == self.audio.num_mels
        ), " [!] model input dimendion must be equal to melspectrogram dimension."
|
||||
@@ -0,0 +1,12 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||
|
||||
|
||||
@dataclass
class EmotionEncoderConfig(BaseEncoderConfig):
    """Defines parameters for Emotion Encoder model.

    Args:
        model (str):
            Model name. Defaults to `"emotion_encoder"`.

        map_classid_to_classname (dict):
            Defaults to None. NOTE(review): presumably maps the integer class ids to
            emotion names; where it is filled is not visible here — confirm.

        class_name_key (str):
            Key of the dataset item field that holds the class label. Defaults to `"emotion_name"`.
    """

    model: str = "emotion_encoder"
    map_classid_to_classname: dict = None
    class_name_key: str = "emotion_name"
|
||||
@@ -0,0 +1,11 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||
|
||||
|
||||
@dataclass
class SpeakerEncoderConfig(BaseEncoderConfig):
    """Defines parameters for Speaker Encoder model.

    Args:
        model (str):
            Model name. Defaults to `"speaker_encoder"`.

        class_name_key (str):
            Key of the dataset item field that holds the class label. Defaults to `"speaker_name"`.
    """

    model: str = "speaker_encoder"
    class_name_key: str = "speaker_name"
|
||||
@@ -0,0 +1,147 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.encoder.utils.generic_utils import AugmentWAV
|
||||
|
||||
|
||||
class EncoderDataset(Dataset):
    """Dataset of labelled utterances for encoder training.

    Items are grouped by the class label found under `config.class_name_key`.
    Classes with too few utterances and utterances not longer than one segment
    are dropped. `collate_fn` turns a batch of items into fixed-length random
    segments (mel spectrograms or raw waveforms) plus their integer class ids.
    """

    def __init__(
        self,
        config,
        ap,
        meta_data,
        voice_len=1.6,
        num_classes_in_batch=64,
        num_utter_per_class=10,
        verbose=False,
        augmentation_config=None,
        use_torch_spec=None,
    ):
        """
        Args:
            config: encoder config; its `class_name_key` names the item field used as class label.
            ap (TTS.tts.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances (dicts with an `audio_file` entry and the class label).
            voice_len (float): voice segment length in seconds.
            num_classes_in_batch (int): only reported in the verbose summary.
            num_utter_per_class (int): minimum number of utterances a class needs to be kept.
            verbose (bool): print diagnostic information.
            augmentation_config (dict): optional; `p` is the probability of augmenting a segment, the
                `additive` / `rir` entries enable `AugmentWAV`, `gaussian` is stored as is.
            use_torch_spec (bool): if truthy, `collate_fn` returns raw waveform segments instead of
                mel spectrograms.
        """
        super().__init__()
        self.config = config
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_utter_per_class = num_utter_per_class
        self.ap = ap
        self.verbose = verbose
        self.use_torch_spec = use_torch_spec
        self.classes, self.items = self.__parse_items()

        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

        # Data Augmentation
        self.augmentator = None
        self.gaussian_augmentation_config = None
        # always defined, so the attribute can be read even without an augmentation config
        self.data_augmentation_p = None
        if augmentation_config:
            self.data_augmentation_p = augmentation_config["p"]
            if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
                self.augmentator = AugmentWAV(ap, augmentation_config)

            if "gaussian" in augmentation_config.keys():
                self.gaussian_augmentation_config = augmentation_config["gaussian"]

        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Classes per Batch: {num_classes_in_batch}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num Classes: {len(self.classes)}")
            print(f" | > Classes: {self.classes}")

    def load_wav(self, filename):
        """Load an audio file with the audio processor at its sample rate."""
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def __parse_items(self):
        """Filter the raw items and collect the class names.

        Returns:
            tuple: the sorted list of kept class names and the list of kept items, each item being
            `{"wav_file_path": ..., "class_name": ...}`.
        """
        # the same label field is used for grouping and for the final items
        class_key = self.config.class_name_key

        class_to_utters = {}
        for item in self.items:
            class_to_utters.setdefault(item[class_key], []).append(item["audio_file"])

        # keep only the classes with at least `self.num_utter_per_class` utterances
        classes = sorted(k for k, v in class_to_utters.items() if len(v) >= self.num_utter_per_class)
        kept_classes = set(classes)

        new_items = []
        for item in self.items:
            path_ = item["audio_file"]
            class_name = item[class_key]
            # ignore filtered classes
            if class_name not in kept_classes:
                continue
            # ignore audios that are not longer than one segment
            if self.load_wav(path_).shape[0] - self.seq_len <= 0:
                continue

            new_items.append({"wav_file_path": path_, "class_name": class_name})

        return classes, new_items

    def __len__(self):
        return len(self.items)

    def get_num_classes(self):
        """Return the number of classes."""
        return len(self.classes)

    def get_class_list(self):
        """Return the list of class names."""
        return self.classes

    def set_classes(self, classes):
        """Replace the class list and rebuild the name-to-id mapping."""
        self.classes = classes
        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

    def get_map_classid_to_classname(self):
        """Return the inverse mapping, class id -> class name."""
        return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())

    def __getitem__(self, idx):
        return self.items[idx]

    def collate_fn(self, batch):
        """Build a training batch from a list of items.

        Each utterance is loaded, cropped to a random `seq_len` window, optionally augmented, and
        converted to a mel spectrogram unless `use_torch_spec` is set.

        Returns:
            tuple: stacked float features and a `LongTensor` of class ids.
        """
        # get the batch class_ids
        labels = []
        feats = []
        for item in batch:
            utter_path = item["wav_file_path"]
            class_name = item["class_name"]

            # get classid
            class_id = self.classname_to_classid[class_name]
            # load wav file
            wav = self.load_wav(utter_path)
            offset = random.randint(0, wav.shape[0] - self.seq_len)
            wav = wav[offset : offset + self.seq_len]

            if self.augmentator is not None and self.data_augmentation_p:
                if random.random() < self.data_augmentation_p:
                    wav = self.augmentator.apply_one(wav)

            if not self.use_torch_spec:
                mel = self.ap.melspectrogram(wav)
                feats.append(torch.FloatTensor(mel))
            else:
                feats.append(torch.FloatTensor(wav))

            labels.append(class_id)

        feats = torch.stack(feats)
        labels = torch.LongTensor(labels)

        return feats, labels
|
||||
@@ -0,0 +1,226 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
# adapted from https://github.com/cvqluu/GE2E-Loss
|
||||
class GE2ELoss(nn.Module):
    def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
        """
        Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
        Accepts an input of size (N, M, D)
            where N is the number of speakers in the batch,
            M is the number of utterances per speaker,
            and D is the dimensionality of the embedding vector (e.g. d-vector)
        Args:
            - init_w (float): defines the initial value of w in Equation (5) of [1]
            - init_b (float): defines the initial value of b in Equation (5) of [1]
            - loss_method (str): "softmax" or "contrast", selects the per-embedding loss
        """
        super().__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.loss_method = loss_method

        print(" > Initialized Generalized End-to-End loss")

        assert self.loss_method in ["softmax", "contrast"]

        self.embed_loss = {
            "softmax": self.embed_loss_softmax,
            "contrast": self.embed_loss_contrast,
        }[self.loss_method]

    # pylint: disable=R0201
    def calc_new_centroids(self, dvecs, centroids, spkr, utt):
        """
        Calculates the new centroids excluding the reference utterance
        """
        # centroid of speaker `spkr` recomputed without utterance `utt`
        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])).mean(0)
        return torch.stack([excl if i == spkr else centroid for i, centroid in enumerate(centroids)])

    def calc_cosine_sim(self, dvecs, centroids):
        """
        Make the cosine similarity matrix with dims (N,M,N)
        """
        cos_sim_matrix = []
        for spkr_idx, speaker in enumerate(dvecs):
            cs_row = []
            for utt_idx, utterance in enumerate(speaker):
                new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx)
                # vector based cosine similarity for speed
                dots = torch.mm(utterance.unsqueeze(0), new_centroids.transpose(0, 1))
                norms = torch.norm(utterance) * torch.norm(new_centroids, dim=1)
                # similarities are floored at 1e-6
                cs_row.append(torch.clamp(dots / norms, 1e-6))
            cos_sim_matrix.append(torch.cat(cs_row, dim=0))
        return torch.stack(cos_sim_matrix)

    # pylint: disable=R0201
    def embed_loss_softmax(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
        """
        N, M, _ = dvecs.shape
        return torch.stack(
            [torch.stack([-F.log_softmax(cos_sim_matrix[j, i], 0)[j] for i in range(M)]) for j in range(N)]
        )

    # pylint: disable=R0201
    def embed_loss_contrast(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
        """
        N, M, _ = dvecs.shape
        L = []
        for j in range(N):
            L_row = []
            for i in range(M):
                centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
                # sigmoid scores against every centroid except the speaker's own
                excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]))
                L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids))
            L.append(torch.stack(L_row))
        return torch.stack(L)

    def forward(self, x, _label=None):
        """
        Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
        """

        assert x.size()[1] >= 2

        centroids = torch.mean(x, 1)
        cos_sim_matrix = self.calc_cosine_sim(x, centroids)
        # `torch.clamp` is out-of-place: its result must be used, otherwise the scale is never clamped
        w = torch.clamp(self.w, min=1e-6)
        cos_sim_matrix = w * cos_sim_matrix + self.b
        L = self.embed_loss(x, cos_sim_matrix)
        return L.mean()
|
||||
|
||||
|
||||
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
|
||||
class AngleProtoLoss(nn.Module):
    """
    Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
        Accepts an input of size (N, M, D)
            where N is the number of speakers in the batch,
            M is the number of utterances per speaker,
            and D is the dimensionality of the embedding vector
        Args:
            - init_w (float): defines the initial value of w
            - init_b (float): defines the initial value of b
    """

    def __init__(self, init_w=10.0, init_b=-5.0):
        super().__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()

        print(" > Initialized Angular Prototypical loss")

    def forward(self, x, _label=None):
        """
        Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
        """

        assert x.size()[1] >= 2

        # the first utterance of each speaker is the query, the mean of the rest is the prototype
        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        num_speakers = out_anchor.size()[0]

        # entry [i, j] is the cosine similarity of query i and prototype j
        cos_sim_matrix = F.cosine_similarity(
            out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),
            out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2),
        )
        # `torch.clamp` is out-of-place: its result must be used, otherwise the scale is never clamped
        w = torch.clamp(self.w, min=1e-6)
        cos_sim_matrix = cos_sim_matrix * w + self.b
        label = torch.arange(num_speakers).to(cos_sim_matrix.device)
        L = self.criterion(cos_sim_matrix, label)
        return L
|
||||
|
||||
|
||||
class SoftmaxLoss(nn.Module):
    """
    Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
        Args:
            - embedding_dim (int): speaker embedding dim
            - n_speakers (int): number of speakers
    """

    def __init__(self, embedding_dim, n_speakers):
        super().__init__()

        self.criterion = torch.nn.CrossEntropyLoss()
        self.fc = nn.Linear(embedding_dim, n_speakers)

        print("Initialised Softmax Loss")

    def forward(self, x, label=None):
        """Cross-entropy of the linear classifier over all embeddings of the batch."""
        # reshape for compatibility: one row per embedding, one label per row
        flat_embeddings = x.reshape(-1, x.shape[-1])
        flat_labels = label.reshape(-1)
        return self.criterion(self.fc(flat_embeddings), flat_labels)

    def inference(self, embedding):
        """Return the index of the most probable class for the given embedding."""
        logits = self.fc(embedding)
        probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze(0)
        return torch.argmax(probabilities)
|
||||
|
||||
|
||||
class SoftmaxAngleProtoLoss(nn.Module):
    """
    Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
        Args:
            - embedding_dim (int): speaker embedding dim
            - n_speakers (int): number of speakers
            - init_w (float): defines the initial value of w
            - init_b (float): defines the initial value of b
    """

    def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
        super().__init__()

        self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
        self.angleproto = AngleProtoLoss(init_w, init_b)

        print("Initialised SoftmaxAnglePrototypical Loss")

    def forward(self, x, label=None):
        """
        Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
        """
        # sum of the two terms: the prototypical term ignores the labels, the softmax term needs them
        proto_term = self.angleproto(x)
        softmax_term = self.softmax(x, label)
        return softmax_term + proto_term
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,161 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.utils.generic_utils import set_init_dict
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
class PreEmphasis(nn.Module):
    """Pre-emphasis filter implemented as a fixed two-tap 1-D convolution.

    Args:
        coefficient (float): weight applied (negated) to the previous sample.
    """

    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        # Kernel of shape (1, 1, 2): out[t] = x[t] - coefficient * x[t - 1].
        kernel = torch.FloatTensor([-self.coefficient, 1.0]).view(1, 1, 2)
        self.register_buffer("filter", kernel)

    def forward(self, x):
        """Filter a 2-D batch of signals; the output has the same shape as the input."""
        assert x.dim() == 2

        # One reflected sample on the left keeps the output length equal to the input length.
        padded = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
        filtered = torch.nn.functional.conv1d(padded, self.filter)
        return filtered.squeeze(1)
|
||||
|
||||
|
||||
class BaseEncoder(nn.Module):
    """Base `encoder` class. Every new `encoder` model must inherit this.

    It defines common `encoder` specific functions. The methods below read
    `self.use_torch_spec`, `self.audio_config` and call `self.forward(x, l2_norm)`,
    none of which are defined here, so subclasses must provide them.
    """

    # pylint: disable=W0102
    def __init__(self):
        super().__init__()

    def get_torch_mel_spectrogram_class(self, audio_config):
        """Build the waveform-to-mel-spectrogram front end.

        Args:
            audio_config: mapping read for the keys "preemphasis", "sample_rate",
                "fft_size", "win_length", "hop_length" and "num_mels".

        Returns:
            torch.nn.Sequential: `PreEmphasis` followed by
            `torchaudio.transforms.MelSpectrogram` using a Hamming window.
        """
        return torch.nn.Sequential(
            PreEmphasis(audio_config["preemphasis"]),
            # TorchSTFT(
            #     n_fft=audio_config["fft_size"],
            #     hop_length=audio_config["hop_length"],
            #     win_length=audio_config["win_length"],
            #     sample_rate=audio_config["sample_rate"],
            #     window="hamming_window",
            #     mel_fmin=0.0,
            #     mel_fmax=None,
            #     use_htk=True,
            #     do_amp_to_db=False,
            #     n_mels=audio_config["num_mels"],
            #     power=2.0,
            #     use_mel=True,
            #     mel_norm=None,
            # )
            torchaudio.transforms.MelSpectrogram(
                sample_rate=audio_config["sample_rate"],
                n_fft=audio_config["fft_size"],
                win_length=audio_config["win_length"],
                hop_length=audio_config["hop_length"],
                window_fn=torch.hamming_window,
                n_mels=audio_config["num_mels"],
            ),
        )

    @torch.no_grad()
    def inference(self, x, l2_norm=True):
        """Run `forward` with gradient tracking disabled."""
        return self.forward(x, l2_norm)

    @torch.no_grad()
    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
        """
        Generate embeddings for a batch of utterances
        x: 1xTxD

        `num_eval` windows of `num_frames` frames are sliced along dim 1 at evenly
        spaced offsets, stacked along dim 0 and embedded in a single call.

        Args:
            x (Tensor): input features, or a waveform when `self.use_torch_spec` is set.
            num_frames (int): window length in frames.
            num_eval (int): number of windows to evaluate.
            return_mean (bool): average the window embeddings into one (keeping dim 0).
            l2_norm (bool): forwarded to `inference`.

        Returns:
            Tensor: the mean embedding, or one embedding per window.
        """
        # map to the waveform size
        if self.use_torch_spec:
            num_frames = num_frames * self.audio_config["hop_length"]

        max_len = x.shape[1]

        # Inputs shorter than one window are used whole.
        if max_len < num_frames:
            num_frames = max_len

        offsets = np.linspace(0, max_len - num_frames, num=num_eval)

        frames_batch = []
        for offset in offsets:
            offset = int(offset)
            end_offset = int(offset + num_frames)
            frames = x[:, offset:end_offset]
            frames_batch.append(frames)

        frames_batch = torch.cat(frames_batch, dim=0)
        embeddings = self.inference(frames_batch, l2_norm=l2_norm)

        if return_mean:
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
        return embeddings

    def get_criterion(self, c: Coqpit, num_classes=None):
        """Instantiate the loss selected by `c.loss`.

        Args:
            c (Coqpit): config; `c.loss` must be "ge2e", "angleproto" or "softmaxproto".
                For "softmaxproto", `c.model_params["proj_dim"]` is read as well.
            num_classes: number of classes, only used by "softmaxproto".

        Raises:
            Exception: if `c.loss` is not one of the supported names.
        """
        if c.loss == "ge2e":
            criterion = GE2ELoss(loss_method="softmax")
        elif c.loss == "angleproto":
            criterion = AngleProtoLoss()
        elif c.loss == "softmaxproto":
            criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
        else:
            raise Exception("The %s not is a loss supported" % c.loss)
        return criterion

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,
        use_cuda: bool = False,
        criterion=None,
        cache=False,
    ):
        """Restore model (and optionally criterion) weights from a checkpoint.

        Args:
            config (Coqpit): model config.
            checkpoint_path (str): checkpoint location passed to `load_fsspec`.
            eval (bool): put the model in eval mode and re-raise any restore error
                instead of falling back to a partial initialization.
            use_cuda (bool): move the model (and criterion, if any) to CUDA.
            criterion: criterion to restore from `state["criterion"]` when present.
            cache (bool): forwarded to `load_fsspec`.

        Returns:
            `criterion` when `eval` is True, otherwise `(criterion, state["step"])`.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        try:
            self.load_state_dict(state["model"])
            print(" > Model fully restored. ")
        except (KeyError, RuntimeError) as error:
            # If eval raise the error
            if eval:
                raise error

            print(" > Partial model initialization.")
            model_dict = self.state_dict()
            # The original passed an undefined name `c` here (NameError on this path).
            # NOTE(review): `config` is the only config in scope; confirm it is what
            # `set_init_dict` expects as its third argument.
            model_dict = set_init_dict(model_dict, state["model"], config)
            self.load_state_dict(model_dict)
            del model_dict

        # load the criterion for restore_path
        if criterion is not None and "criterion" in state:
            try:
                criterion.load_state_dict(state["criterion"])
            except (KeyError, RuntimeError) as error:
                print(" > Criterion load ignored because of:", error)

        # instance and load the criterion for the encoder classifier in inference time
        if (
            eval
            and criterion is None
            and "criterion" in state
            and getattr(config, "map_classid_to_classname", None) is not None
        ):
            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

        if use_cuda:
            self.cuda()
            if criterion is not None:
                criterion = criterion.cuda()

        if eval:
            self.eval()
            assert not self.training

        if not eval:
            return criterion, state["step"]
        return criterion
|
||||
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.encoder.models.base_encoder import BaseEncoder
|
||||
|
||||
|
||||
class LSTMWithProjection(nn.Module):
    """Single-layer batch-first LSTM followed by a bias-free linear projection.

    Args:
        input_size (int): feature size of the input sequence.
        hidden_size (int): LSTM hidden size.
        proj_size (int): output size of the projection applied to every time step.
    """

    def __init__(self, input_size, hidden_size, proj_size):
        super().__init__()
        self.input_size, self.hidden_size, self.proj_size = input_size, hidden_size, proj_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, proj_size, bias=False)

    def forward(self, x):
        """Return the projected LSTM output for every time step of `x`."""
        self.lstm.flatten_parameters()
        sequence_out, _state = self.lstm(x)
        return self.linear(sequence_out)
|
||||
|
||||
|
||||
class LSTMWithoutProjection(nn.Module):
    """Multi-layer batch-first LSTM whose final hidden state goes through Linear + ReLU.

    Args:
        input_dim (int): feature size of the input sequence.
        lstm_dim (int): LSTM hidden size.
        proj_dim (int): output size of the linear layer.
        num_lstm_layers (int): number of stacked LSTM layers.
    """

    def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Embed `x` using the last layer's final hidden state."""
        _outputs, state = self.lstm(x)
        # state is (h_n, c_n); h_n[-1] is the final hidden state of the top layer.
        top_hidden = state[0][-1]
        projected = self.linear(top_hidden)
        return self.relu(projected)
|
||||
|
||||
|
||||
class LSTMSpeakerEncoder(BaseEncoder):
    """LSTM-based speaker encoder.

    Args:
        input_dim: feature size of the input, also used for the instance norm.
        proj_dim: embedding size.
        lstm_dim: LSTM hidden size.
        num_lstm_layers: number of LSTM layers.
        use_lstm_with_projection: stack `LSTMWithProjection` blocks when True,
            otherwise use a single `LSTMWithoutProjection`.
        use_torch_spec: compute the spectrogram from the waveform inside `forward`.
        audio_config: audio settings forwarded to `get_torch_mel_spectrogram_class`.
    """

    def __init__(
        self,
        input_dim,
        proj_dim=256,
        lstm_dim=768,
        num_lstm_layers=3,
        use_lstm_with_projection=True,
        use_torch_spec=False,
        audio_config=None,
    ):
        super().__init__()
        self.use_lstm_with_projection = use_lstm_with_projection
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        self.proj_dim = proj_dim

        # choose the LSTM stack
        if use_lstm_with_projection:
            # The first block consumes the input features, the rest consume projections.
            stack = [LSTMWithProjection(input_dim, lstm_dim, proj_dim)]
            stack.extend(LSTMWithProjection(proj_dim, lstm_dim, proj_dim) for _ in range(num_lstm_layers - 1))
            self.layers = nn.Sequential(*stack)
        else:
            self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)

        self.instancenorm = nn.InstanceNorm1d(input_dim)

        self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) if self.use_torch_spec else None

        self._init_layers()

    def _init_layers(self):
        """Zero every bias and Xavier-normal-initialize every weight in `self.layers`."""
        for param_name, tensor in self.layers.named_parameters():
            if "bias" in param_name:
                nn.init.constant_(tensor, 0.0)
                continue
            if "weight" in param_name:
                nn.init.xavier_normal_(tensor)

    def forward(self, x, l2_norm=True):
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
                to compute the spectrogram on-the-fly.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        # Feature extraction and normalization run without gradients and with autocast disabled.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            if self.use_torch_spec:
                x.squeeze_(1)
                x = self.torch_spec(x)
            x = self.instancenorm(x).transpose(1, 2)
        embedding = self.layers(x)
        if self.use_lstm_with_projection:
            # The projection stack returns every time step; keep the last one.
            embedding = embedding[:, -1]
        if l2_norm:
            embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
        return embedding
|
||||
@@ -0,0 +1,198 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.encoder.models.base_encoder import BaseEncoder
|
||||
|
||||
|
||||
class SELayer(nn.Module):
    """Channel gating layer: global average pool, two linear layers, sigmoid, rescale.

    Args:
        channel (int): number of input channels.
        reduction (int): divisor for the hidden width of the gating network.
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of the 4-D input by its learned gate."""
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gates = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gates
|
||||
|
||||
|
||||
class SEBasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with an `SELayer` on the main branch.

    Args:
        inplanes (int): input channels.
        planes (int): output channels.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the residual.
        reduction (int): reduction factor passed to `SELayer`.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-relu-bn, conv-bn, SE gating, then add the residual and apply ReLU."""
        shortcut = x if self.downsample is None else self.downsample(x)

        # Note the ReLU precedes the first batch norm in this block.
        out = self.bn1(self.relu(self.conv1(x)))
        out = self.se(self.bn2(self.conv2(out)))

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class ResNetSpeakerEncoder(BaseEncoder):
    """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
    Adapted from: https://github.com/clovaai/voxceleb_trainer
    """

    # pylint: disable=W0102
    def __init__(
        self,
        input_dim=64,
        proj_dim=512,
        layers=[3, 4, 6, 3],
        num_filters=[32, 64, 128, 256],
        encoder_type="ASP",
        log_input=False,
        use_torch_spec=False,
        audio_config=None,
    ):
        super().__init__()

        self.encoder_type = encoder_type
        self.input_dim = input_dim
        self.log_input = log_input
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        self.proj_dim = proj_dim

        # Stem convolution over the single-channel spectrogram "image".
        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(num_filters[0])

        # `create_layer` reads and updates `self.inplanes` as the stages are built.
        self.inplanes = num_filters[0]
        self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
        self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
        self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
        self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))

        self.instancenorm = nn.InstanceNorm1d(input_dim)

        self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) if self.use_torch_spec else None

        # Three stride-2 stages divide the frequency axis by 8.
        outmap_size = int(self.input_dim / 8)
        pooled_channels = num_filters[3] * outmap_size

        self.attention = nn.Sequential(
            nn.Conv1d(pooled_channels, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, pooled_channels, kernel_size=1),
            nn.Softmax(dim=2),
        )

        if self.encoder_type == "ASP":
            # Mean and standard deviation are concatenated in `forward`.
            out_dim = pooled_channels * 2
        elif self.encoder_type == "SAP":
            out_dim = pooled_channels
        else:
            raise ValueError("Undefined encoder")

        self.fc = nn.Linear(out_dim, proj_dim)

        self._init_layers()

    def _init_layers(self):
        """Kaiming-initialize Conv2d weights and reset BatchNorm2d to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def create_layer(self, block, planes, blocks, stride=1):
        """Build one stage of `blocks` residual blocks and update `self.inplanes`."""
        out_planes = planes * block.expansion
        downsample = None
        # A projection shortcut is needed whenever the first block changes resolution or width.
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    # pylint: disable=R0201
    def new_parameter(self, *size):
        """Return a new Xavier-normal-initialized `nn.Parameter` of the given size."""
        param = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(param)
        return param

    def forward(self, x, l2_norm=False):
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
                to compute the spectrogram on-the-fly.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        x.squeeze_(1)
        # compute the spectrogram here when configured, otherwise use the one computed by the AP
        if self.use_torch_spec:
            x = self.torch_spec(x)

        if self.log_input:
            x = (x + 1e-6).log()
        x = self.instancenorm(x).unsqueeze(1)

        x = self.bn1(self.relu(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Merge the channel and frequency axes; time stays last.
        x = x.reshape(x.size()[0], -1, x.size()[-1])

        w = self.attention(x)

        if self.encoder_type == "SAP":
            x = torch.sum(x * w, dim=2)
        elif self.encoder_type == "ASP":
            mu = torch.sum(x * w, dim=2)
            # Clamp before the square root to keep the weighted variance positive.
            sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
            x = torch.cat((mu, sg), 1)

        x = self.fc(x.view(x.size()[0], -1))

        if l2_norm:
            x = torch.nn.functional.normalize(x, p=2, dim=1)
        return x
|
||||
@@ -0,0 +1,2 @@
|
||||
umap-learn
|
||||
numpy>=1.17.0
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,136 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
|
||||
|
||||
class AugmentWAV(object):
    """Waveform augmentation with additive noise and/or room impulse responses.

    Args:
        ap: audio processor; `ap.load_wav(path, sr=...)` and `ap.sample_rate` are used.
        augmentation_config: mapping with optional "additive" and "rir" sections.
            "additive" needs "sounds_path" plus one sub-dict per noise type holding
            "min_num_noises", "max_num_noises", "min_snr_in_db" and "max_snr_in_db".
            "rir" needs "rir_path" and "conv_mode".
    """

    def __init__(self, ap, augmentation_config):
        self.ap = ap
        self.use_additive_noise = False

        if "additive" in augmentation_config.keys():
            self.additive_noise_config = augmentation_config["additive"]
            additive_path = self.additive_noise_config["sounds_path"]
            if additive_path:
                self.use_additive_noise = True
                # get noise types: every dict-valued entry of the config is one noise type
                self.additive_noise_types = []
                for key in self.additive_noise_config.keys():
                    if isinstance(self.additive_noise_config[key], dict):
                        self.additive_noise_types.append(key)

                additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)

                self.noise_list = {}

                for wav_file in additive_files:
                    # First path component below `additive_path`. relpath works whether or
                    # not `additive_path` ends with a separator; the previous
                    # str.replace-based version produced "" without one and dropped every file.
                    noise_dir = os.path.relpath(wav_file, additive_path).split(os.sep)[0]
                    # ignore not listed directories
                    if noise_dir not in self.additive_noise_types:
                        continue
                    if noise_dir not in self.noise_list:
                        self.noise_list[noise_dir] = []
                    self.noise_list[noise_dir].append(wav_file)

                print(
                    f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
                )

        self.use_rir = False

        if "rir" in augmentation_config.keys():
            self.rir_config = augmentation_config["rir"]
            if self.rir_config["rir_path"]:
                self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
                self.use_rir = True

                print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")

        self.create_augmentation_global_list()

    def create_augmentation_global_list(self):
        """Build `self.global_noise_list`, the augmentation names `apply_one` picks from."""
        if self.use_additive_noise:
            # Copy so that appending "RIR_AUG" below does not leak into additive_noise_types.
            self.global_noise_list = list(self.additive_noise_types)
        else:
            self.global_noise_list = []
        if self.use_rir:
            self.global_noise_list.append("RIR_AUG")

    def additive_noise(self, noise_type, audio):
        """Mix randomly chosen noise files of `noise_type` into `audio`.

        Args:
            noise_type: key of `self.noise_list` / the additive config.
            audio (np.ndarray): waveform; noise is truncated to its length.

        Returns:
            np.ndarray: `audio` plus the scaled noise signals.
        """
        noise_config = self.additive_noise_config[noise_type]
        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)

        noise_list = random.sample(
            self.noise_list[noise_type],
            random.randint(
                noise_config["min_num_noises"],
                noise_config["max_num_noises"],
            ),
        )

        # The upper SNR bound was previously read from "max_num_noises" (a file count).
        # Fall back to that old value only when "max_snr_in_db" is absent so that configs
        # which worked before keep loading.
        max_snr = noise_config.get("max_snr_in_db", noise_config["max_num_noises"])

        audio_len = audio.shape[0]
        noises_wav = None
        for noise in noise_list:
            noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len]

            # Noise clips shorter than the audio cannot be added sample-wise; skip them.
            if noiseaudio.shape[0] < audio_len:
                continue

            noise_snr = random.uniform(noise_config["min_snr_in_db"], max_snr)
            noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
            # Scale the noise so that its level sits `noise_snr` dB below the clean signal.
            noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

            if noises_wav is None:
                noises_wav = noise_wav
            else:
                noises_wav += noise_wav

        # if all possible files is less than audio, choose other files
        # NOTE: this retries by recursion and does not terminate if every file is too short.
        if noises_wav is None:
            return self.additive_noise(noise_type, audio)

        return audio + noises_wav

    def reverberate(self, audio):
        """Convolve `audio` with a random, energy-normalized impulse response.

        The result is truncated to the length of `audio`.
        """
        audio_len = audio.shape[0]

        rir_file = random.choice(self.rir_files)
        rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
        rir = rir / np.sqrt(np.sum(rir**2))
        return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]

    def apply_one(self, audio):
        """Apply one randomly chosen augmentation from `self.global_noise_list`."""
        noise_type = random.choice(self.global_noise_list)
        if noise_type == "RIR_AUG":
            return self.reverberate(audio)

        return self.additive_noise(noise_type, audio)
|
||||
|
||||
|
||||
def setup_encoder_model(config: "Coqpit"):
    """Instantiate the speaker encoder selected by `config.model_params["model_name"]`.

    Args:
        config (Coqpit): config providing `model_params` (with "model_name",
            "input_dim", "proj_dim" and, for the LSTM model, "lstm_dim" and
            "num_lstm_layers") and `audio`.

    Returns:
        LSTMSpeakerEncoder or ResNetSpeakerEncoder: the constructed model.

    Raises:
        ValueError: if the model name is neither "lstm" nor "resnet" (case-insensitive).
    """
    model_name = config.model_params["model_name"].lower()
    if model_name == "lstm":
        return LSTMSpeakerEncoder(
            config.model_params["input_dim"],
            config.model_params["proj_dim"],
            config.model_params["lstm_dim"],
            config.model_params["num_lstm_layers"],
            use_torch_spec=config.model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    if model_name == "resnet":
        return ResNetSpeakerEncoder(
            input_dim=config.model_params["input_dim"],
            proj_dim=config.model_params["proj_dim"],
            log_input=config.model_params.get("log_input", False),
            use_torch_spec=config.model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    # Previously an unknown name fell through to `return model` and raised UnboundLocalError.
    raise ValueError(
        f"Unknown encoder model name {config.model_params['model_name']!r}; expected 'lstm' or 'resnet'."
    )
|
||||
@@ -0,0 +1,219 @@
|
||||
# coding=utf-8
|
||||
# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Only support eager mode and TF>=2.0.0
|
||||
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
|
||||
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
|
||||
""" voxceleb 1 & 2 """
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import zipfile
|
||||
|
||||
import pandas
|
||||
import soundfile as sf
|
||||
from absl import logging
|
||||
|
||||
# Common prefix of every VoxCeleb download URL below.
_VOXCELEB_URL_ROOT = "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a"

# Download URLs per subset. The dev sets are split into parts named "..._partaa", "..._partab", ...
SUBSETS = {
    "vox1_dev_wav": [f"{_VOXCELEB_URL_ROOT}/vox1_dev_wav_parta{letter}" for letter in "abcd"],
    "vox1_test_wav": [f"{_VOXCELEB_URL_ROOT}/vox1_test_wav.zip"],
    "vox2_dev_aac": [f"{_VOXCELEB_URL_ROOT}/vox2_dev_aac_parta{letter}" for letter in "abcdefgh"],
    "vox2_test_aac": [f"{_VOXCELEB_URL_ROOT}/vox2_test_aac.zip"],
}

# Expected md5 digest of each subset's zip archive.
MD5SUM = dict(
    vox1_dev_wav="ae63e55b951748cc486645f532ba230b",
    vox2_dev_aac="bbc063c46078a602ca71605645c2a402",
    vox1_test_wav="185fdc63c3c739954633d50379a3d102",
    vox2_test_aac="0d2b3ea430a821c33263b5ea37ede312",
)

# Download credentials; filled in from the command line by the script entry point.
USER = dict(user="", password="")

# Speaker name -> integer id, filled while the subsets are processed.
speaker_id_dict = {}
|
||||
|
||||
|
||||
def download_and_extract(directory, subset, urls):
    """Download and extract the given split of dataset.

    Args:
        directory: the directory where to put the downloaded data.
        subset: subset name of the corpus.
        urls: the list of urls to download the data file.

    Raises:
        ValueError: if `urls` is empty or the archive's md5 does not match `MD5SUM[subset]`.
    """
    import shlex  # pylint: disable=import-outside-toplevel

    if not urls:
        raise ValueError("No urls given for subset %s" % subset)

    os.makedirs(directory, exist_ok=True)

    for url in urls:
        zip_filepath = os.path.join(directory, url.split("/")[-1])
        # Files that are already present are not downloaded again.
        if os.path.exists(zip_filepath):
            continue
        logging.info("Downloading %s to %s" % (url, zip_filepath))
        # Argument list instead of a shell string: credentials and paths are passed
        # verbatim and cannot be interpreted by the shell.
        subprocess.call(["wget", url, "--user", USER["user"], "--password", USER["password"], "-O", zip_filepath])

        statinfo = os.stat(zip_filepath)
        logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))

    # concatenate all parts into zip files
    if ".zip" not in zip_filepath:
        zip_filepath = "_".join(zip_filepath.split("_")[:-1])
        # The glob must be expanded by the shell, so keep shell=True but quote the prefix.
        quoted_prefix = shlex.quote(zip_filepath)
        subprocess.call("cat %s* > %s.zip" % (quoted_prefix, quoted_prefix), shell=True)
        zip_filepath += ".zip"
    # str.strip(".zip") removes any of the characters '.', 'z', 'i', 'p' from BOTH ends
    # of the path; drop the extension explicitly instead.
    extract_path = zip_filepath[: -len(".zip")] if zip_filepath.endswith(".zip") else zip_filepath

    # check zip file md5sum, hashing in chunks so the archive is never fully in memory
    digest = hashlib.md5()
    with open(zip_filepath, "rb") as f_zip:
        for chunk in iter(lambda: f_zip.read(1 << 20), b""):
            digest.update(chunk)
    if digest.hexdigest() != MD5SUM[subset]:
        raise ValueError("md5sum of %s mismatch" % zip_filepath)

    with zipfile.ZipFile(zip_filepath, "r") as zfile:
        zfile.extractall(directory)
        # Rename the archive's first entry to the subset name.
        extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
        subprocess.call(["mv", extract_path_ori, extract_path])
    # The downloaded archive is deliberately kept (no os.remove(zip_filepath)).
|
||||
|
||||
|
||||
def exec_cmd(cmd):
    """Run a command in a subprocess.

    Args:
        cmd: command line to be executed (run through the shell).
    Return:
        int, the return code, or -999 if the command could not be started.
    """
    try:
        status = subprocess.call(cmd, shell=True)
    except OSError as e:
        logging.info(f"Execution failed: {e}")
        return -999
    # A negative return code is reported as termination by a signal.
    if status < 0:
        logging.info(f"Child was terminated by signal {status}")
    return status
|
||||
|
||||
|
||||
def decode_aac_with_ffmpeg(aac_file, wav_file):
    """Decode a given AAC file into WAV using ffmpeg.

    Args:
        aac_file: file path to input AAC file.
        wav_file: file path to output WAV file.
    Return:
        bool, True if success.
    """
    import shlex  # pylint: disable=import-outside-toplevel

    # The command is run through a shell by `exec_cmd`, so quote both paths; otherwise
    # spaces or shell metacharacters in a file name break (or alter) the command.
    cmd = f"ffmpeg -i {shlex.quote(aac_file)} {shlex.quote(wav_file)}"
    logging.info(f"Decoding aac file using command line: {cmd}")
    ret = exec_cmd(cmd)
    if ret != 0:
        logging.error(f"Failed to decode aac file with retcode {ret}")
        logging.error("Please check your ffmpeg installation.")
        return False
    return True
|
||||
|
||||
|
||||
def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
    """Optionally convert AAC to WAV and make speaker labels.

    Args:
        input_dir: the directory which holds the input dataset.
        subset: the name of the specified subset. e.g. vox1_dev_wav
        output_dir: the directory to place the newly generated csv files.
        output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv

    Raises:
        RuntimeError: if an ``.m4a`` file cannot be decoded.
    """
    logging.info("Preprocessing audio and label for subset %s" % subset)
    source_dir = os.path.join(input_dir, subset)

    rows = []
    # Convert all AAC file into WAV format. At the same time, generate the csv
    for root, _, filenames in os.walk(source_dir):
        for filename in filenames:
            stem, extension = os.path.splitext(filename)
            extension = extension.lower()
            if extension == ".wav":
                # Skip "<name>.<ext>.wav" files, e.g. the "<file>.m4a.wav" written by the
                # branch below; they are listed via their source file.
                if os.path.splitext(stem)[1]:
                    continue
                wav_file = os.path.join(root, filename)
            elif extension == ".m4a":
                # Convert AAC to WAV.
                aac_file = os.path.join(root, filename)
                wav_file = aac_file + ".wav"
                if not os.path.exists(wav_file) and not decode_aac_with_ffmpeg(aac_file, wav_file):
                    raise RuntimeError("Audio decoding failed.")
            else:
                continue
            # The speaker name is the grandparent directory of the audio file.
            speaker_name = root.split(os.path.sep)[-2]
            speaker_id = speaker_id_dict.setdefault(speaker_name, len(speaker_id_dict))
            # NOTE(review): this is the length of the decoded sample array, although the
            # csv column is called "wav_length_ms" -- confirm the intended unit.
            wav_length = len(sf.read(wav_file)[0])
            rows.append((os.path.abspath(wav_file), wav_length, speaker_id, speaker_name))

    # Write to CSV file which contains four columns:
    # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
    csv_file_path = os.path.join(output_dir, output_file)
    columns = ["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]
    pandas.DataFrame(data=rows, columns=columns).to_csv(csv_file_path, index=False, sep="\t")
    logging.info("Successfully generated csv file {}".format(csv_file_path))
|
||||
|
||||
|
||||
def processor(directory, subset, force_process):
    """Download and process one subset, returning the path of its csv file.

    Args:
        directory: directory that receives the downloads and the csv.
        subset: key of `SUBSETS`.
        force_process: redo the work even if the csv already exists.

    Raises:
        ValueError: if `subset` is not a key of `SUBSETS`.
    """
    if subset not in SUBSETS:
        raise ValueError(subset, "is not in voxceleb")

    csv_name = subset + ".csv"
    subset_csv = os.path.join(directory, csv_name)
    # Reuse an existing csv unless the caller forces reprocessing.
    already_done = os.path.exists(subset_csv) if not force_process else False
    if already_done:
        return subset_csv

    logging.info("Downloading and process the voxceleb in %s", directory)
    logging.info("Preparing subset %s", subset)
    download_and_extract(directory, subset, SUBSETS[subset])
    convert_audio_and_make_label(directory, subset, directory, csv_name)
    logging.info("Finished downloading and processing")
    return subset_csv
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: download and process every subset into the given directory.
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) != 4:
        print("Usage: python prepare_data.py save_directory user password")
        # Exit with a non-zero status so callers can detect the usage error.
        sys.exit(1)

    DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3]
    for SUBSET in SUBSETS:
        processor(DIR, SUBSET, False)
|
||||
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from coqpit import Coqpit
|
||||
from trainer import TrainerArgs, get_last_checkpoint
|
||||
from trainer.io import copy_model_files
|
||||
from trainer.logging import logger_factory
|
||||
from trainer.logging.console_logger import ConsoleLogger
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.tts.utils.text.characters import parse_symbols
|
||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
|
||||
|
||||
@dataclass
class TrainArgs(TrainerArgs):
    """`TrainerArgs` extended with the path of the model config file."""

    # Path of the config file to load; `process_args` overwrites it with
    # "<continue_path>/config.json" when a run is continued.
    config_path: str = field(default=None, metadata={"help": "Path to the config file."})
|
||||
|
||||
|
||||
def getarguments():
    """Return the argument parser generated from a default `TrainArgs` (no argument prefix)."""
    return TrainArgs().init_argparse(arg_prefix="")
|
||||
|
||||
|
||||
def process_args(args, config=None):
    """Process parsed command line arguments and initialize the config if not provided.

    Args:
        args (argparse.Namespace or tuple): Parsed input arguments, or the
            `(namespace, remaining_args)` pair returned by `parse_known_args`. The
            remaining arguments are applied to the config as overrides.
        config (Coqpit): Model config. If none, it is loaded from `args.config_path` or
            generated from the override arguments. Defaults to None.

    Returns:
        config (Coqpit): Config parameters.
        experiment_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (ConsoleLogger): Class that does logging to the console.
        dashboard_logger: Dashboard logger built by `logger_factory`; None unless `args.rank == 0`.

    Raises:
        ValueError: if no config, no config path and no override arguments are given.

    TODO:
        - Interactive config definition.
    """
    # Without this default a bare namespace left `coqpit_overrides` unbound and the
    # function failed further down.
    coqpit_overrides = None
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training from its output folder
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # init config if not already defined
    if config is None:
        if args.config_path:
            # init from a file
            config = load_config(args.config_path)
        else:
            if coqpit_overrides is None:
                raise ValueError("Cannot build a config: no `config`, no `config_path` and no override arguments.")
            # init from console args
            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

            config_base = BaseTrainingConfig()
            config_base.parse_known_args(coqpit_overrides)
            config = register_config(config_base.model)()
    # override values from command-line args
    if coqpit_overrides is not None:
        config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
    audio_path = os.path.join(experiment_path, "test_audios")
    config.output_log_path = experiment_path
    # setup rank 0 process in distributed training
    dashboard_logger = None
    if args.rank == 0:
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters") and config.characters is None:
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        dashboard_logger = logger_factory(config, experiment_path)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, dashboard_logger
|
||||
|
||||
|
||||
def init_arguments():
    """Create the argument parser used to start a training run.

    Returns:
        The parser created by ``TrainArgs().init_argparse`` with an empty
        argument prefix.
    """
    return TrainArgs().init_argparse(arg_prefix="")
|
||||
|
||||
|
||||
def init_training(config: Coqpit = None):
    """Initialization of a training run.

    Parses the command line with the parser from ``init_arguments`` and hands the
    ``parse_known_args`` result, together with ``config``, to ``process_args``.

    Returns:
        tuple: The first element of the ``parse_known_args`` result followed by
        the five values returned by ``process_args``.
    """
    parsed = init_arguments().parse_known_args()
    config, out_path, audio_path, console_logger, dash_logger = process_args(parsed, config)
    return parsed[0], config, out_path, audio_path, console_logger, dash_logger
|
||||
@@ -0,0 +1,50 @@
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import umap
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
|
||||
# RGB colours (one row per class), scaled from 0-255 integers to 0-1 floats.
_COLORMAP_RGB_255 = [
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
]
colormap = np.array(_COLORMAP_RGB_255, dtype=float) / 255
|
||||
|
||||
|
||||
def plot_embeddings(embeddings, num_classes_in_batch):
    """Project embeddings with UMAP and draw them as a class-coloured scatter plot.

    The figure is also written to ``"umap"`` through ``plt.savefig``.

    Args:
        embeddings: Array whose first dimension is split evenly across the classes
            (``shape[0] // num_classes_in_batch`` rows per class).
        num_classes_in_batch (int): Number of classes; only the first 10 are plotted.

    Returns:
        The matplotlib figure holding the scatter plot.
    """
    utterances_per_class = embeddings.shape[0] // num_classes_in_batch

    # if necessary get just the first 10 classes
    if num_classes_in_batch > 10:
        num_classes_in_batch = 10
        embeddings = embeddings[: num_classes_in_batch * utterances_per_class]

    projection = umap.UMAP().fit_transform(embeddings)

    # class index of every row, then one colour per row
    class_ids = np.repeat(np.arange(num_classes_in_batch), utterances_per_class)
    point_colors = [colormap[class_id] for class_id in class_ids]

    fig, ax = plt.subplots(figsize=(16, 10))
    xs, ys = projection[:, 0], projection[:, 1]
    ax.scatter(xs, ys, c=point_colors)
    plt.gca().set_aspect("equal", "datalim")
    plt.title("UMAP projection")
    plt.tight_layout()
    plt.savefig("umap")
    return fig
|
||||
@@ -0,0 +1,59 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from trainer import TrainerModel
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseTrainerModel(TrainerModel):
    """Abstract model interface that extends ``TrainerModel`` with the hooks 🐸TTS requires.

    Every new 🐸TTS model must inherit it.
    """

    @staticmethod
    @abstractmethod
    def init_from_config(config: Coqpit):
        """Init the model and all its attributes from the given config.

        Override this depending on your model.
        """

    @abstractmethod
    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
        """Forward pass for inference.

        It must return a dictionary with the main model output and all the auxiliary
        outputs. The key ``model_outputs`` is considered to be the main output and you
        can add any other auxiliary outputs as you want.

        We don't use `*kwargs` since it is problematic with the TorchScript API.

        Args:
            input (torch.Tensor): Model input.
            aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.

        Returns:
            Dict: Output dictionary; this placeholder returns ``{"model_outputs": None}``.
        """
        return {"model_outputs": None}

    @abstractmethod
    def load_checkpoint(
        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
    ) -> None:
        """Load a model checkpoint file and get ready for training or inference.

        Args:
            config (Coqpit): Model configuration.
            checkpoint_path (str): Path to the model checkpoint file.
            eval (bool, optional): If true, init model for inference else for training. Defaults to False.
            strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under
                `get_user_data_dir()/tts_cache`. Defaults to False.
        """
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
||||
from TTS.utils.audio.processor import AudioProcessor
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,485 @@
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import scipy
|
||||
import soundfile as sf
|
||||
from librosa import magphase, pyin
|
||||
|
||||
# For using kwargs
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
|
||||
def build_mel_basis(
    *,
    sample_rate: int = None,
    fft_size: int = None,
    num_mels: int = None,
    mel_fmax: int = None,
    mel_fmin: int = None,
    **kwargs,
) -> np.ndarray:
    """Build the mel filter bank through ``librosa.filters.mel``.

    When ``mel_fmax`` is given it must not exceed ``sample_rate // 2`` and must be
    larger than ``mel_fmin``.

    Returns:
        np.ndarray: melspectrogram basis.
    """
    has_upper_bound = mel_fmax is not None
    if has_upper_bound:
        nyquist = sample_rate // 2
        assert mel_fmax <= nyquist
        assert mel_fmax - mel_fmin > 0
    filter_args = {
        "sr": sample_rate,
        "n_fft": fft_size,
        "n_mels": num_mels,
        "fmin": mel_fmin,
        "fmax": mel_fmax,
    }
    return librosa.filters.mel(**filter_args)
|
||||
|
||||
|
||||
def millisec_to_length(
    *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
) -> Tuple[int, int]:
    """Compute window and hop length (in samples) from milliseconds.

    Returns:
        Tuple[int, int]: ``(win_length, hop_length)`` for STFT, in that order.
    """
    ratio = frame_length_ms / frame_shift_ms
    assert ratio.is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
    window_samples = int(frame_length_ms / 1000.0 * sample_rate)
    hop_samples = int(window_samples / float(ratio))
    return window_samples, hop_samples
|
||||
|
||||
|
||||
def _log(x, base):
    """Return ``log10(x)`` when ``base`` is 10, otherwise the natural logarithm."""
    return np.log10(x) if base == 10 else np.log(x)
|
||||
|
||||
|
||||
def _exp(x, base):
    """Return ``10 ** x`` when ``base`` is 10, otherwise ``e ** x``."""
    return np.power(10, x) if base == 10 else np.exp(x)
|
||||
|
||||
|
||||
def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
    """Convert amplitude values to decibels.

    Values are floored at ``1e-8`` before the logarithm is taken.

    Args:
        x (np.ndarray): Amplitude spectrogram.
        gain (float): Gain factor. Defaults to 1.
        base (int): Logarithm base. Defaults to 10.

    Returns:
        np.ndarray: Decibels spectrogram.
    """
    assert not (x < 0).any(), " [!] Input values must be non-negative."
    floored = np.maximum(1e-8, x)
    return gain * _log(floored, base)
|
||||
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
    """Convert decibels spectrogram to amplitude spectrogram.

    Args:
        x (np.ndarray): Decibels spectrogram.
        gain (float): Gain factor. Defaults to 1.
        base (int): Logarithm base. Defaults to 10.

    Returns:
        np.ndarray: Amplitude spectrogram.
    """
    exponent = x / gain
    return _exp(exponent, base)
|
||||
|
||||
|
||||
def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
    """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.

    Args:
        x (np.ndarray): Audio signal.
        coef (float): Pre-emphasis coefficient. Defaults to 0.97.

    Raises:
        RuntimeError: Preemphasis coeff is set to 0.

    Returns:
        np.ndarray: Decorrelated audio signal.
    """
    if coef == 0:
        raise RuntimeError(" [!] Preemphasis is set 0.0.")
    # y[n] = x[n] - coef * x[n - 1]
    numerator, denominator = [1, -coef], [1]
    return scipy.signal.lfilter(numerator, denominator, x)
|
||||
|
||||
|
||||
def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
    """Reverse pre-emphasis.

    Raises:
        RuntimeError: ``coef`` is 0.
    """
    if coef == 0:
        raise RuntimeError(" [!] Preemphasis is set 0.0.")
    # y[n] = x[n] + coef * y[n - 1]
    numerator, denominator = [1], [1, -coef]
    return scipy.signal.lfilter(numerator, denominator, x)
|
||||
|
||||
|
||||
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
    """Convert a full scale linear spectrogram output of a network to a melspectrogram.

    Args:
        spec (np.ndarray): Normalized full scale linear spectrogram.
        mel_basis (np.ndarray): Mel filter bank that is multiplied with ``spec``.

    Shapes:
        - spec: :math:`[C, T]`

    Returns:
        np.ndarray: Normalized melspectrogram.
    """
    mel = np.dot(mel_basis, spec)
    return mel
|
||||
|
||||
|
||||
def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
    """Convert a melspectrogram to full scale spectrogram.

    The mel basis is inverted with its pseudo-inverse and the result is floored
    at ``1e-10``.
    """
    assert not (mel < 0).any(), " [!] Input values must be non-negative."
    pseudo_inverse = np.linalg.pinv(mel_basis)
    linear = np.dot(pseudo_inverse, mel)
    return np.maximum(1e-10, linear)
|
||||
|
||||
|
||||
def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
    """Compute a magnitude spectrogram from a waveform.

    Args:
        wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
        **kwargs: Forwarded to ``stft``.

    Returns:
        np.ndarray: ``float32`` spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length`
    """
    magnitude = np.abs(stft(y=wav, **kwargs))
    return magnitude.astype(np.float32)
|
||||
|
||||
|
||||
def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
    """Compute a melspectrogram from a waveform.

    The STFT magnitude of ``wav`` is projected with ``mel_basis`` and returned as
    ``float32``. ``**kwargs`` are forwarded to both ``stft`` and ``spec_to_mel``.
    """
    magnitude = np.abs(stft(y=wav, **kwargs))
    mel = spec_to_mel(spec=magnitude, mel_basis=mel_basis, **kwargs)
    return mel.astype(np.float32)
|
||||
|
||||
|
||||
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
    """Convert a spectrogram to a waveform using the Griffin-Lim vocoder.

    A copy of ``spec`` is raised to ``power`` before it is handed to
    ``griffin_lim`` together with ``**kwargs``.
    """
    sharpened = spec.copy() ** power
    return griffin_lim(spec=sharpened, **kwargs)
|
||||
|
||||
|
||||
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
    """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder.

    ``kwargs["mel_basis"]`` is required; it is used to map the mel spectrogram
    back to a linear one before Griffin-Lim runs.
    """
    # Convert back to linear
    linear = mel_to_spec(mel=mel.copy(), mel_basis=kwargs["mel_basis"])
    return griffin_lim(spec=linear**power, **kwargs)
|
||||
|
||||
|
||||
### STFT and ISTFT ###
|
||||
def stft(
    *,
    y: np.ndarray = None,
    fft_size: int = None,
    hop_length: int = None,
    win_length: int = None,
    pad_mode: str = "reflect",
    window: str = "hann",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Librosa STFT wrapper.

    ``fft_size`` is passed to librosa as ``n_fft``; extra ``**kwargs`` are ignored.

    Check http://librosa.org/doc/main/generated/librosa.stft.html argument details.

    Returns:
        np.ndarray: Complex number array.
    """
    librosa_args = {
        "y": y,
        "n_fft": fft_size,
        "hop_length": hop_length,
        "win_length": win_length,
        "pad_mode": pad_mode,
        "window": window,
        "center": center,
    }
    return librosa.stft(**librosa_args)
|
||||
|
||||
|
||||
def istft(
    *,
    y: np.ndarray = None,
    hop_length: int = None,
    win_length: int = None,
    window: str = "hann",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Librosa iSTFT wrapper.

    ``y`` is passed positionally as the STFT matrix; extra ``**kwargs`` are ignored.

    Check http://librosa.org/doc/main/generated/librosa.istft.html argument details.

    Returns:
        np.ndarray: Signal reconstructed by ``librosa.istft``.
    """
    librosa_args = {
        "hop_length": hop_length,
        "win_length": win_length,
        "center": center,
        "window": window,
    }
    return librosa.istft(y, **librosa_args)
|
||||
|
||||
|
||||
def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
    """Estimate a waveform from a magnitude spectrogram by iterative phase refinement.

    Starts from random phases, then for ``num_iter`` rounds re-estimates the phase
    from the STFT of the current waveform while keeping the magnitude fixed.

    Args:
        spec (np.ndarray): Spectrogram; its absolute value is used as the magnitude.
        num_iter (int): Number of refinement iterations. Defaults to 60.
        **kwargs: Forwarded to ``stft`` and ``istft``.

    Returns:
        np.ndarray: Reconstructed waveform, or ``np.array([0.0])`` when the first
        inversion contains non-finite values.
    """
    random_phase = np.exp(2j * np.pi * np.random.rand(*spec.shape))
    magnitude = np.abs(spec).astype(complex)
    waveform = istft(y=magnitude * random_phase, **kwargs)

    if not np.isfinite(waveform).all():
        print(" [!] Waveform is not finite everywhere. Skipping the GL.")
        return np.array([0.0])

    for _ in range(num_iter):
        # keep the target magnitude, take the phase of the current estimate
        phase = np.exp(1j * np.angle(stft(y=waveform, **kwargs)))
        waveform = istft(y=magnitude * phase, **kwargs)
    return waveform
|
||||
|
||||
|
||||
def compute_stft_paddings(
    *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
) -> Tuple[int, int]:
    """Compute paddings used by Librosa's STFT.

    The total padding extends ``x.shape[0]`` to the next multiple of
    ``hop_length`` (a full hop is added when it already is a multiple).

    Returns:
        Tuple[int, int]: ``(left, right)`` padding. Everything goes to the right
        unless ``pad_two_sides`` is set, in which case it is split in half and the
        odd sample goes to the right.
    """
    num_samples = x.shape[0]
    total = (num_samples // hop_length + 1) * hop_length - num_samples
    if pad_two_sides:
        half, odd = divmod(total, 2)
        return half, half + odd
    return 0, total
|
||||
|
||||
|
||||
def compute_f0(
    *,
    x: np.ndarray = None,
    pitch_fmax: float = None,
    pitch_fmin: float = None,
    hop_length: int = None,
    win_length: int = None,
    sample_rate: int = None,
    stft_pad_mode: str = "reflect",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.

    The estimate comes from librosa's ``pyin``; frames it marks as unvoiced are set to 0.

    Args:
        x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
        pitch_fmax (float): Pitch max value.
        pitch_fmin (float): Pitch min value.
        hop_length (int): Hop length handed to ``pyin``.
        win_length (int): Handed to ``pyin`` as ``frame_length``; half of it is used as ``pyin``'s ``win_length``.
        sample_rate (int): Audio sampling rate.
        stft_pad_mode (str): Padding mode for STFT.
        center (bool): Centered padding.

    Returns:
        np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`

    Examples:
        >>> WAV_FILE = filename = librosa.example('vibeace')
        >>> from TTS.config import BaseAudioConfig
        >>> from TTS.utils.audio import AudioProcessor
        >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
        >>> ap = AudioProcessor(**conf)
        >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
        >>> pitch = ap.compute_f0(wav)
    """
    assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
    assert pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."

    pyin_options = {
        "fmin": pitch_fmin,
        "fmax": pitch_fmax,
        "sr": sample_rate,
        "frame_length": win_length,
        "win_length": win_length // 2,
        "hop_length": hop_length,
        "pad_mode": stft_pad_mode,
        "center": center,
        "n_thresholds": 100,
        "beta_parameters": (2, 18),
        "boltzmann_parameter": 2,
        "resolution": 0.1,
        "max_transition_rate": 35.92,
        "switch_prob": 0.01,
        "no_trough_prob": 0.01,
    }
    pitch, is_voiced, _ = pyin(y=x.astype(np.double), **pyin_options)
    # unvoiced frames carry no pitch
    pitch[~is_voiced] = 0.0

    return pitch
|
||||
|
||||
|
||||
def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
    """Compute energy of a waveform using the same parameters used for computing melspectrogram.

    The energy of a frame is the L2 norm of its STFT magnitudes.

    Args:
        y (np.ndarray): Waveform. Shape :math:`[T_wav,]`
        **kwargs: Forwarded to ``stft``.

    Returns:
        np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length`

    Examples:
        >>> WAV_FILE = filename = librosa.example('vibeace')
        >>> from TTS.config import BaseAudioConfig
        >>> from TTS.utils.audio import AudioProcessor
        >>> conf = BaseAudioConfig()
        >>> ap = AudioProcessor(**conf)
        >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
        >>> energy = ap.compute_energy(wav)
    """
    magnitude, _ = magphase(stft(y=y, **kwargs))
    squared_sum = np.sum(magnitude**2, axis=0)
    return np.sqrt(squared_sum)
|
||||
|
||||
|
||||
### Audio Processing ###
|
||||
def find_endpoint(
    *,
    wav: np.ndarray = None,
    trim_db: float = -40,
    sample_rate: int = None,
    min_silence_sec=0.8,
    gain: float = None,
    base: int = None,
    **kwargs,
) -> int:
    """Find the last point without silence at the end of a audio signal.

    The signal is scanned with a window of ``min_silence_sec`` seconds that moves
    by a quarter window; the scan stops at the first window whose maximum value is
    below the silence threshold.

    Args:
        wav (np.ndarray): Audio signal.
        trim_db (float, optional): Silence threshold in decibels. The value is negated
            before it is converted to an amplitude. Defaults to -40.
        sample_rate (int): Audio sampling rate, used to size the window.
        min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8.
        gain (float, optional): Gain to be used to convert trim_db to trim_amp. When None,
            the default of ``db_to_amp`` is used. Defaults to None.
        base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. When
            None, the default of ``db_to_amp`` is used. Defaults to None.

    Returns:
        int: Last point without silence, or ``len(wav)`` when no silent window is found.
    """
    window_length = int(sample_rate * min_silence_sec)
    hop_length = int(window_length / 4)

    # Forward gain/base only when they were supplied: passing None through made
    # db_to_amp divide by None and raise a TypeError.
    conversion = {}
    if gain is not None:
        conversion["gain"] = gain
    if base is not None:
        conversion["base"] = base
    threshold = db_to_amp(x=-trim_db, **conversion)

    for start in range(hop_length, len(wav) - window_length, hop_length):
        if np.max(wav[start : start + window_length]) < threshold:
            return start + hop_length
    return len(wav)
|
||||
|
||||
|
||||
def trim_silence(
    *,
    wav: np.ndarray = None,
    sample_rate: int = None,
    trim_db: float = None,
    win_length: int = None,
    hop_length: int = None,
    **kwargs,
) -> np.ndarray:
    """Trim silent parts with a threshold and 0.01 sec margin.

    Args:
        wav (np.ndarray): Audio signal.
        sample_rate (int): Audio sampling rate, used to size the margin.
        trim_db (float): Passed to ``librosa.effects.trim`` as ``top_db``.
        win_length (int): Passed to ``librosa.effects.trim`` as ``frame_length``.
        hop_length (int): Passed to ``librosa.effects.trim`` as ``hop_length``.

    Returns:
        np.ndarray: The trimmed signal (first element of the ``librosa.effects.trim`` result).
    """
    margin = int(sample_rate * 0.01)
    # A zero margin must not be sliced: wav[0:-0] is wav[0:0], i.e. an empty signal.
    if margin > 0:
        wav = wav[margin:-margin]
    return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
|
||||
|
||||
|
||||
def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
    """Normalize the volume of an audio signal.

    The signal is rescaled so that its largest absolute value equals ``coef``.
    An all-zero signal has no peak to scale to and is returned as a zero signal
    instead of NaNs.

    Args:
        x (np.ndarray): Raw waveform.
        coef (float): Coefficient to rescale the maximum value. Defaults to 0.95.

    Returns:
        np.ndarray: Volume normalized waveform.
    """
    peak = np.abs(x).max()
    if peak == 0:
        # silent input: dividing by the peak would produce NaN everywhere
        return x * coef
    return x / peak * coef
|
||||
|
||||
|
||||
def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
    """Scale a waveform so that its RMS matches ``db_level``.

    A signal without energy (all zeros or empty) cannot be scaled to a target RMS
    and is returned as an unchanged copy instead of NaNs.

    Args:
        wav (np.ndarray): Raw waveform.
        db_level (float): Target dB level in RMS. Defaults to -27.0.

    Returns:
        np.ndarray: RMS normalized waveform.
    """
    target_rms = 10 ** (db_level / 20)
    energy = np.sum(wav**2)
    if energy == 0:
        # silent input: the scale factor would be a division by zero
        return wav.copy()
    scale = np.sqrt((len(wav) * (target_rms**2)) / energy)
    return wav * scale
|
||||
|
||||
|
||||
def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
    """Normalize the volume based on RMS of the signal.

    Validates ``db_level`` and delegates the scaling to ``rms_norm``.

    Args:
        x (np.ndarray): Raw waveform.
        db_level (float): Target dB level in RMS. Must be within [-99, 0]. Defaults to -27.0.

    Returns:
        np.ndarray: RMS normalized waveform.
    """
    assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
    return rms_norm(wav=x, db_level=db_level)
|
||||
|
||||
|
||||
def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
    """Read a wav file, optionally resampling it while loading.

    Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.

    Args:
        filename (str): Path to the wav file.
        sample_rate (int, optional): Sampling rate for resampling. Only used when ``resample`` is True.
            Defaults to None.
        resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.

    Returns:
        np.ndarray: Loaded waveform.
    """
    if not resample:
        # SF is faster than librosa for loading files
        samples, _ = sf.read(filename)
        return samples
    # loading with resampling. It is significantly slower.
    samples, _ = librosa.load(filename, sr=sample_rate)
    return samples
|
||||
|
||||
|
||||
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None:
    """Save float waveform to a file using Scipy.

    The waveform is rescaled so that its peak (at least 0.01) maps to 32767 and is
    stored as 16-bit integers. When ``pipe_out`` is given the encoded wav is written
    to ``pipe_out.buffer`` in addition to ``path``.

    Args:
        wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
        path (str): Path to a output file.
        sample_rate (int, optional): Sampling rate used for saving to the file. Defaults to None.
        pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
    """
    peak = max(0.01, np.max(np.abs(wav)))
    pcm16 = (wav * (32767 / peak)).astype(np.int16)

    if pipe_out:
        staging = BytesIO()
        scipy.io.wavfile.write(staging, sample_rate, pcm16)
        staging.seek(0)
        encoded = staging.read()
        pipe_out.buffer.write(encoded)
    scipy.io.wavfile.write(path, sample_rate, pcm16)
|
||||
|
||||
|
||||
def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
    """Compress a waveform with the mu-law curve and quantize it.

    Args:
        wav (np.ndarray): Waveform to encode.
        mulaw_qc (int): Number of bits; ``mu`` is ``2**mulaw_qc - 1``.

    Returns:
        np.ndarray: Floored values obtained by mapping the companded signal
        from ``[-1, 1]`` onto ``[0, mu]``.
    """
    mu = 2**mulaw_qc - 1
    companded = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
    # shift to [0, mu]; the +0.5 makes the floor below round to nearest
    shifted = (companded + 1) / 2 * mu + 0.5
    return np.floor(shifted)
|
||||
|
||||
|
||||
def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
    """Recovers waveform from quantized values.

    Applies the inverse mu-law curve with ``mu = 2**mulaw_qc - 1`` to ``wav``.
    """
    mu = 2**mulaw_qc - 1
    expanded = (1 + mu) ** np.abs(wav) - 1
    return np.sign(wav) / mu * expanded
|
||||
|
||||
|
||||
def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
    """Scale a waveform by ``2**15``, clip it to the int16 range and cast it to ``np.int16``."""
    full_scale = 2**15
    scaled = x * full_scale
    return np.clip(scaled, -full_scale, full_scale - 1).astype(np.int16)
|
||||
|
||||
|
||||
def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
    """Quantize a waveform to a given number of bits.

    Maps ``[-1, 1]`` linearly onto ``[0, 2**quantize_bits - 1]``; the result is
    not rounded.

    Args:
        x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
        quantize_bits (int): Number of quantization bits.

    Returns:
        np.ndarray: Quantized waveform.
    """
    top_level = 2**quantize_bits - 1
    return (x + 1.0) * top_level / 2
|
||||
|
||||
|
||||
def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
    """Dequantize a waveform from the given number of bits.

    Maps ``[0, 2**quantize_bits - 1]`` linearly back onto ``[-1, 1]``.
    """
    top_level = 2**quantize_bits - 1
    return 2 * x / top_level - 1
|
||||
@@ -0,0 +1,633 @@
|
||||
from io import BytesIO
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
import scipy.signal
|
||||
|
||||
from TTS.tts.utils.helpers import StandardScaler
|
||||
from TTS.utils.audio.numpy_transforms import (
|
||||
amp_to_db,
|
||||
build_mel_basis,
|
||||
compute_f0,
|
||||
db_to_amp,
|
||||
deemphasis,
|
||||
find_endpoint,
|
||||
griffin_lim,
|
||||
load_wav,
|
||||
mel_to_spec,
|
||||
millisec_to_length,
|
||||
preemphasis,
|
||||
rms_volume_norm,
|
||||
spec_to_mel,
|
||||
stft,
|
||||
trim_silence,
|
||||
volume_norm,
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
|
||||
|
||||
class AudioProcessor(object):
|
||||
"""Audio Processor for TTS.
|
||||
|
||||
Note:
|
||||
All the class arguments are set to default values to enable a flexible initialization
|
||||
of the class with the model config. They are not meaningful for all the arguments.
|
||||
|
||||
Args:
|
||||
sample_rate (int, optional):
|
||||
target audio sampling rate. Defaults to None.
|
||||
|
||||
resample (bool, optional):
|
||||
enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
|
||||
|
||||
num_mels (int, optional):
|
||||
number of melspectrogram dimensions. Defaults to None.
|
||||
|
||||
log_func (int, optional):
|
||||
log exponent used for converting spectrogram aplitude to DB.
|
||||
|
||||
min_level_db (int, optional):
|
||||
minimum db threshold for the computed melspectrograms. Defaults to None.
|
||||
|
||||
frame_shift_ms (int, optional):
|
||||
milliseconds of frames between STFT columns. Defaults to None.
|
||||
|
||||
frame_length_ms (int, optional):
|
||||
milliseconds of STFT window length. Defaults to None.
|
||||
|
||||
hop_length (int, optional):
|
||||
number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
|
||||
|
||||
win_length (int, optional):
|
||||
STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
|
||||
|
||||
ref_level_db (int, optional):
|
||||
reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
|
||||
|
||||
fft_size (int, optional):
|
||||
FFT window size for STFT. Defaults to 1024.
|
||||
|
||||
power (int, optional):
|
||||
Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
|
||||
|
||||
preemphasis (float, optional):
|
||||
Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
|
||||
|
||||
signal_norm (bool, optional):
|
||||
enable/disable signal normalization. Defaults to None.
|
||||
|
||||
symmetric_norm (bool, optional):
|
||||
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
|
||||
|
||||
max_norm (float, optional):
|
||||
```k``` defining the normalization range. Defaults to None.
|
||||
|
||||
mel_fmin (int, optional):
|
||||
minimum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
mel_fmax (int, optional):
|
||||
maximum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
pitch_fmin (int, optional):
|
||||
minimum filter frequency for computing pitch. Defaults to None.
|
||||
|
||||
pitch_fmax (int, optional):
|
||||
maximum filter frequency for computing pitch. Defaults to None.
|
||||
|
||||
spec_gain (int, optional):
|
||||
gain applied when converting amplitude to DB. Defaults to 20.
|
||||
|
||||
stft_pad_mode (str, optional):
|
||||
Padding mode for STFT. Defaults to 'reflect'.
|
||||
|
||||
clip_norm (bool, optional):
|
||||
enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
|
||||
|
||||
griffin_lim_iters (int, optional):
|
||||
Number of GriffinLim iterations. Defaults to None.
|
||||
|
||||
do_trim_silence (bool, optional):
|
||||
enable/disable silence trimming when loading the audio signal. Defaults to False.
|
||||
|
||||
trim_db (int, optional):
|
||||
DB threshold used for silence trimming. Defaults to 60.
|
||||
|
||||
do_sound_norm (bool, optional):
|
||||
enable/disable signal normalization. Defaults to False.
|
||||
|
||||
do_amp_to_db_linear (bool, optional):
|
||||
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
|
||||
|
||||
do_amp_to_db_mel (bool, optional):
|
||||
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
||||
|
||||
do_rms_norm (bool, optional):
|
||||
enable/disable RMS volume normalization when loading an audio file. Defaults to False.
|
||||
|
||||
db_level (int, optional):
|
||||
dB level used for rms normalization. The range is -99 to 0. Defaults to None.
|
||||
|
||||
stats_path (str, optional):
|
||||
Path to the computed stats file. Defaults to None.
|
||||
|
||||
verbose (bool, optional):
|
||||
enable/disable logging. Defaults to True.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=None,
|
||||
resample=False,
|
||||
num_mels=None,
|
||||
log_func="np.log10",
|
||||
min_level_db=None,
|
||||
frame_shift_ms=None,
|
||||
frame_length_ms=None,
|
||||
hop_length=None,
|
||||
win_length=None,
|
||||
ref_level_db=None,
|
||||
fft_size=1024,
|
||||
power=None,
|
||||
preemphasis=0.0,
|
||||
signal_norm=None,
|
||||
symmetric_norm=None,
|
||||
max_norm=None,
|
||||
mel_fmin=None,
|
||||
mel_fmax=None,
|
||||
pitch_fmax=None,
|
||||
pitch_fmin=None,
|
||||
spec_gain=20,
|
||||
stft_pad_mode="reflect",
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=None,
|
||||
do_trim_silence=False,
|
||||
trim_db=60,
|
||||
do_sound_norm=False,
|
||||
do_amp_to_db_linear=True,
|
||||
do_amp_to_db_mel=True,
|
||||
do_rms_norm=False,
|
||||
db_level=None,
|
||||
stats_path=None,
|
||||
verbose=True,
|
||||
**_,
|
||||
):
|
||||
# setup class attributed
|
||||
self.sample_rate = sample_rate
|
||||
self.resample = resample
|
||||
self.num_mels = num_mels
|
||||
self.log_func = log_func
|
||||
self.min_level_db = min_level_db or 0
|
||||
self.frame_shift_ms = frame_shift_ms
|
||||
self.frame_length_ms = frame_length_ms
|
||||
self.ref_level_db = ref_level_db
|
||||
self.fft_size = fft_size
|
||||
self.power = power
|
||||
self.preemphasis = preemphasis
|
||||
self.griffin_lim_iters = griffin_lim_iters
|
||||
self.signal_norm = signal_norm
|
||||
self.symmetric_norm = symmetric_norm
|
||||
self.mel_fmin = mel_fmin or 0
|
||||
self.mel_fmax = mel_fmax
|
||||
self.pitch_fmin = pitch_fmin
|
||||
self.pitch_fmax = pitch_fmax
|
||||
self.spec_gain = float(spec_gain)
|
||||
self.stft_pad_mode = stft_pad_mode
|
||||
self.max_norm = 1.0 if max_norm is None else float(max_norm)
|
||||
self.clip_norm = clip_norm
|
||||
self.do_trim_silence = do_trim_silence
|
||||
self.trim_db = trim_db
|
||||
self.do_sound_norm = do_sound_norm
|
||||
self.do_amp_to_db_linear = do_amp_to_db_linear
|
||||
self.do_amp_to_db_mel = do_amp_to_db_mel
|
||||
self.do_rms_norm = do_rms_norm
|
||||
self.db_level = db_level
|
||||
self.stats_path = stats_path
|
||||
# setup exp_func for db to amp conversion
|
||||
if log_func == "np.log":
|
||||
self.base = np.e
|
||||
elif log_func == "np.log10":
|
||||
self.base = 10
|
||||
else:
|
||||
raise ValueError(" [!] unknown `log_func` value.")
|
||||
# setup stft parameters
|
||||
if hop_length is None:
|
||||
# compute stft parameters from given time values
|
||||
self.win_length, self.hop_length = millisec_to_length(
|
||||
frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
|
||||
)
|
||||
else:
|
||||
# use stft parameters from config file
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
assert min_level_db != 0.0, " [!] min_level_db is 0"
|
||||
assert (
|
||||
self.win_length <= self.fft_size
|
||||
), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
|
||||
members = vars(self)
|
||||
if verbose:
|
||||
print(" > Setting up Audio Processor...")
|
||||
for key, value in members.items():
|
||||
print(" | > {}:{}".format(key, value))
|
||||
# create spectrogram utils
|
||||
self.mel_basis = build_mel_basis(
|
||||
sample_rate=self.sample_rate,
|
||||
fft_size=self.fft_size,
|
||||
num_mels=self.num_mels,
|
||||
mel_fmax=self.mel_fmax,
|
||||
mel_fmin=self.mel_fmin,
|
||||
)
|
||||
# setup scaler
|
||||
if stats_path and signal_norm:
|
||||
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
|
||||
self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
|
||||
self.signal_norm = True
|
||||
self.max_norm = None
|
||||
self.clip_norm = None
|
||||
self.symmetric_norm = None
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Coqpit", verbose=True):
|
||||
if "audio" in config:
|
||||
return AudioProcessor(verbose=verbose, **config.audio)
|
||||
return AudioProcessor(verbose=verbose, **config)
|
||||
|
||||
### normalization ###
|
||||
def normalize(self, S: np.ndarray) -> np.ndarray:
    """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`.

    When `self.signal_norm` is false, a copy of the input is returned untouched. When mean-variance
    scalers are attached (`self.mel_scaler` exists), the scaler matching the number of rows of `S`
    is applied. Otherwise range normalization driven by `ref_level_db`, `min_level_db`, `max_norm`,
    `symmetric_norm` and `clip_norm` is used.

    Args:
        S (np.ndarray): Spectrogram to normalize. Its first axis is compared against
            `self.num_mels` and the linear spectrogram bin count.

    Raises:
        RuntimeError: Mean-var scalers are set but the first dimension of `S` matches neither the
            mel nor the linear feature size.

    Returns:
        np.ndarray: Normalized spectrogram.
    """
    # pylint: disable=no-else-return
    S = S.copy()
    if not self.signal_norm:
        return S

    # mean-var scaling
    if hasattr(self, "mel_scaler"):
        num_bins = S.shape[0]
        if num_bins == self.num_mels:
            return self.mel_scaler.transform(S.T).T
        # A one-sided STFT yields `fft_size // 2 + 1` frequency bins. The previous check against
        # `fft_size / 2` is still accepted so existing callers keep working.
        if num_bins == self.fft_size // 2 + 1 or num_bins == self.fft_size / 2:
            return self.linear_scaler.transform(S.T).T
        raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")

    # range normalization
    S -= self.ref_level_db  # discard certain range of DB assuming it is air noise
    S_norm = (S - self.min_level_db) / (-self.min_level_db)
    if self.symmetric_norm:
        S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
        lower_bound = -self.max_norm  # pylint: disable=invalid-unary-operand-type
    else:
        S_norm = self.max_norm * S_norm
        lower_bound = 0
    if self.clip_norm:
        S_norm = np.clip(S_norm, lower_bound, self.max_norm)
    return S_norm
|
||||
|
||||
def denormalize(self, S: np.ndarray) -> np.ndarray:
    """Denormalize spectrogram values (inverse of `normalize`).

    When `self.signal_norm` is false, a copy of the input is returned untouched. When mean-variance
    scalers are attached (`self.mel_scaler` exists), the inverse transform of the scaler matching
    the number of rows of `S` is applied. Otherwise the range normalization is reverted and
    `ref_level_db` is added back.

    Args:
        S (np.ndarray): Spectrogram to denormalize. Its first axis is compared against
            `self.num_mels` and the linear spectrogram bin count.

    Raises:
        RuntimeError: Mean-var scalers are set but the first dimension of `S` matches neither the
            mel nor the linear feature size.

    Returns:
        np.ndarray: Denormalized spectrogram.
    """
    # pylint: disable=no-else-return
    S_denorm = S.copy()
    if not self.signal_norm:
        return S_denorm

    # mean-var scaling
    if hasattr(self, "mel_scaler"):
        num_bins = S_denorm.shape[0]
        if num_bins == self.num_mels:
            return self.mel_scaler.inverse_transform(S_denorm.T).T
        # A one-sided STFT yields `fft_size // 2 + 1` frequency bins. The previous check against
        # `fft_size / 2` is still accepted so existing callers keep working.
        if num_bins == self.fft_size // 2 + 1 or num_bins == self.fft_size / 2:
            return self.linear_scaler.inverse_transform(S_denorm.T).T
        raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")

    if self.symmetric_norm:
        if self.clip_norm:
            S_denorm = np.clip(
                S_denorm, -self.max_norm, self.max_norm  # pylint: disable=invalid-unary-operand-type
            )
        S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
    else:
        if self.clip_norm:
            S_denorm = np.clip(S_denorm, 0, self.max_norm)
        S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
    return S_denorm + self.ref_level_db
|
||||
|
||||
### Mean-STD scaling ###
|
||||
def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]:
    """Load mean and variance statistics from a `npy` file and validate them against this processor.

    Every entry of the stored `audio_config` is compared with the attribute of the same name on
    this instance, except for a fixed set of parameters that are skipped.

    Args:
        stats_path (str): Path to the `npy` file holding a pickled dictionary with the keys
            `mel_mean`, `mel_std`, `linear_mean`, `linear_std` and `audio_config`.

    Raises:
        AssertionError: A checked audio parameter differs from the value stored with the stats.

    Returns:
        Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to
            compute them.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary objects - only load stats files from trusted sources.
    stats = np.load(stats_path, allow_pickle=True).item()  # pylint: disable=unexpected-keyword-arg
    mel_mean, mel_std, linear_mean, linear_std, stats_config = (
        stats[field] for field in ("mel_mean", "mel_std", "linear_mean", "linear_std", "audio_config")
    )
    # parameters that are not compared with the values used for computing the stats
    unchecked = {
        "griffin_lim_iters",
        "stats_path",
        "do_trim_silence",
        "ref_level_db",
        "power",
        "sample_rate",
        "trim_db",
    }
    for key in stats_config:
        if key not in unchecked:
            assert (
                stats_config[key] == self.__dict__[key]
            ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
    return mel_mean, mel_std, linear_mean, linear_std, stats_config
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
def setup_scaler(
    self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray
) -> None:
    """Create the scaler objects used in mean-std normalization.

    Sets `self.mel_scaler` and `self.linear_scaler`, each a `StandardScaler` loaded with the
    given statistics through `set_stats`.

    Args:
        mel_mean (np.ndarray): Mean for melspectrograms.
        mel_std (np.ndarray): STD for melspectrograms.
        linear_mean (np.ndarray): Mean for full scale spectrograms.
        linear_std (np.ndarray): STD for full scale spectrograms.
    """
    scaler_specs = (
        ("mel_scaler", mel_mean, mel_std),
        ("linear_scaler", linear_mean, linear_std),
    )
    for attribute, mean, std in scaler_specs:
        # attach first, then load the statistics (same order as assigning and configuring in place)
        setattr(self, attribute, StandardScaler())
        getattr(self, attribute).set_stats(mean, std)
|
||||
|
||||
### Preemphasis ###
|
||||
def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
    """Apply pre-emphasis to the audio signal using the coefficient stored in `self.preemphasis`.

    Useful to reduce the correlation between neighbouring signal values.

    Args:
        x (np.ndarray): Audio signal.

    Raises:
        RuntimeError: Preemphasis coeff is set to 0 (presumably raised by the `preemphasis`
            helper, which is not visible here - confirm).

    Returns:
        np.ndarray: Decorrelated audio signal.
    """
    coefficient = self.preemphasis
    return preemphasis(x=x, coef=coefficient)
|
||||
|
||||
def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
    """Reverse pre-emphasis by running `deemphasis` with the coefficient in `self.preemphasis`.

    Args:
        x (np.ndarray): Pre-emphasized audio signal.

    Returns:
        np.ndarray: Output of the `deemphasis` helper.
    """
    coefficient = self.preemphasis
    return deemphasis(x=x, coef=coefficient)
|
||||
|
||||
### SPECTROGRAMs ###
|
||||
def spectrogram(self, y: np.ndarray) -> np.ndarray:
    """Compute a normalized spectrogram from a waveform.

    Pre-emphasis is applied when `self.preemphasis` is non-zero, and the STFT magnitude is converted
    to dB when `self.do_amp_to_db_linear` is set. The result goes through `self.normalize`.

    Args:
        y (np.ndarray): Waveform.

    Returns:
        np.ndarray: Spectrogram as `float32`.
    """
    signal = self.apply_preemphasis(y) if self.preemphasis != 0 else y
    stft_params = {
        "fft_size": self.fft_size,
        "hop_length": self.hop_length,
        "win_length": self.win_length,
        "pad_mode": self.stft_pad_mode,
    }
    magnitude = np.abs(stft(y=signal, **stft_params))
    if self.do_amp_to_db_linear:
        magnitude = amp_to_db(x=magnitude, gain=self.spec_gain, base=self.base)
    return self.normalize(magnitude).astype(np.float32)
|
||||
|
||||
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
    """Compute a normalized melspectrogram from a waveform.

    Pre-emphasis is applied when `self.preemphasis` is non-zero. The STFT magnitude is projected
    with `self.mel_basis` and converted to dB when `self.do_amp_to_db_mel` is set. The result goes
    through `self.normalize`.

    Args:
        y (np.ndarray): Waveform.

    Returns:
        np.ndarray: Melspectrogram as `float32`.
    """
    signal = self.apply_preemphasis(y) if self.preemphasis != 0 else y
    stft_params = {
        "fft_size": self.fft_size,
        "hop_length": self.hop_length,
        "win_length": self.win_length,
        "pad_mode": self.stft_pad_mode,
    }
    magnitude = np.abs(stft(y=signal, **stft_params))
    mel = spec_to_mel(spec=magnitude, mel_basis=self.mel_basis)
    if self.do_amp_to_db_mel:
        mel = amp_to_db(x=mel, gain=self.spec_gain, base=self.base)
    return self.normalize(mel).astype(np.float32)
|
||||
|
||||
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
    """Convert a spectrogram to a waveform using the Griffin-Lim vocoder.

    Args:
        spectrogram (np.ndarray): Normalized spectrogram; it is denormalized and converted from dB
            to amplitude before phase reconstruction.

    Returns:
        np.ndarray: Waveform, with inverse pre-emphasis applied when `self.preemphasis` is non-zero.
    """
    amplitude = db_to_amp(x=self.denormalize(spectrogram), gain=self.spec_gain, base=self.base)
    # Reconstruct phase
    waveform = self._griffin_lim(amplitude**self.power)
    if self.preemphasis != 0:
        return self.apply_inv_preemphasis(waveform)
    return waveform
|
||||
|
||||
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
    """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder.

    Args:
        mel_spectrogram (np.ndarray): Normalized melspectrogram; it is denormalized, converted from
            dB to amplitude and mapped back to the linear scale with `self.mel_basis`.

    Returns:
        np.ndarray: Waveform, with inverse pre-emphasis applied when `self.preemphasis` is non-zero.
    """
    mel_amplitude = db_to_amp(x=self.denormalize(mel_spectrogram), gain=self.spec_gain, base=self.base)
    linear = mel_to_spec(mel=mel_amplitude, mel_basis=self.mel_basis)  # Convert back to linear
    waveform = self._griffin_lim(linear**self.power)
    if self.preemphasis != 0:
        return self.apply_inv_preemphasis(waveform)
    return waveform
|
||||
|
||||
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
    """Convert a full scale linear spectrogram output of a network to a melspectrogram.

    The input is denormalized and converted to amplitude, projected with `self.mel_basis`,
    converted back to dB and normalized again.

    Args:
        linear_spec (np.ndarray): Normalized full scale linear spectrogram.

    Returns:
        np.ndarray: Normalized melspectrogram.
    """
    conversion = {"gain": self.spec_gain, "base": self.base}
    amplitude = db_to_amp(x=self.denormalize(linear_spec), **conversion)
    mel_amplitude = spec_to_mel(spec=np.abs(amplitude), mel_basis=self.mel_basis)
    mel_db = amp_to_db(x=mel_amplitude, **conversion)
    return self.normalize(mel_db)
|
||||
|
||||
def _griffin_lim(self, S):
    """Run the `griffin_lim` helper on `S` with this processor's STFT settings.

    Args:
        S: Spectrogram passed to the helper as `spec`.

    Returns:
        Whatever `griffin_lim` returns (presumably the reconstructed waveform - confirm).
    """
    settings = {
        "num_iter": self.griffin_lim_iters,
        "hop_length": self.hop_length,
        "win_length": self.win_length,
        "fft_size": self.fft_size,
        "pad_mode": self.stft_pad_mode,
    }
    return griffin_lim(spec=S, **settings)
|
||||
|
||||
def compute_f0(self, x: np.ndarray) -> np.ndarray:
    """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.

    Args:
        x (np.ndarray): Waveform.

    Returns:
        np.ndarray: Pitch.

    Examples:
        >>> WAV_FILE = filename = librosa.example('vibeace')
        >>> from TTS.config import BaseAudioConfig
        >>> from TTS.utils.audio import AudioProcessor
        >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
        >>> ap = AudioProcessor(**conf)
        >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
        >>> pitch = ap.compute_f0(wav)
    """
    # align F0 length to the spectrogram length: a waveform whose length is an exact multiple of
    # the hop gets half a hop of padding at its end
    needs_padding = len(x) % self.hop_length == 0
    if needs_padding:
        tail = self.hop_length // 2
        x = np.pad(x, (0, tail), mode=self.stft_pad_mode)

    pitch_settings = {
        "pitch_fmax": self.pitch_fmax,
        "pitch_fmin": self.pitch_fmin,
        "hop_length": self.hop_length,
        "win_length": self.win_length,
        "sample_rate": self.sample_rate,
        "stft_pad_mode": self.stft_pad_mode,
        "center": True,
    }
    return compute_f0(x=x, **pitch_settings)
|
||||
|
||||
### Audio Processing ###
|
||||
def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int:
    """Find the last point without silence at the end of an audio signal.

    The silence threshold is not an argument of this method; `self.trim_db` is forwarded to the
    `find_endpoint` helper together with the sample rate and dB conversion settings.

    Args:
        wav (np.ndarray): Audio signal.
        min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8.

    Returns:
        int: Last point without silence.
    """
    detection_settings = {
        "trim_db": self.trim_db,
        "sample_rate": self.sample_rate,
        "min_silence_sec": min_silence_sec,
        "gain": self.spec_gain,
        "base": self.base,
    }
    return find_endpoint(wav=wav, **detection_settings)
|
||||
|
||||
def trim_silence(self, wav):
    """Trim silent parts of `wav` with the `trim_silence` helper.

    The helper receives `self.trim_db` as threshold along with the sample rate, window length and
    hop length of this processor.

    Args:
        wav: Audio signal to trim.

    Returns:
        Output of the `trim_silence` helper (presumably the trimmed signal - confirm).
    """
    trim_settings = {
        "sample_rate": self.sample_rate,
        "trim_db": self.trim_db,
        "win_length": self.win_length,
        "hop_length": self.hop_length,
    }
    return trim_silence(wav=wav, **trim_settings)
|
||||
|
||||
@staticmethod
def sound_norm(x: np.ndarray) -> np.ndarray:
    """Normalize the volume of an audio signal by delegating to `volume_norm`.

    Args:
        x (np.ndarray): Raw waveform.

    Returns:
        np.ndarray: Volume normalized waveform.
    """
    normalized = volume_norm(x=x)
    return normalized
|
||||
|
||||
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
    """Normalize the volume based on RMS of the signal.

    Args:
        x (np.ndarray): Raw waveform.
        db_level (float, optional): Target level forwarded to the `rms_volume_norm` helper.
            Falls back to `self.db_level` when None. Defaults to None.

    Returns:
        np.ndarray: RMS normalized waveform.
    """
    target_level = self.db_level if db_level is None else db_level
    return rms_volume_norm(x=x, db_level=target_level)
|
||||
|
||||
### save and load ###
|
||||
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
    """Read a wav file and optionally resample, silence trim and volume normalize it.

    Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.

    Args:
        filename (str): Path to the wav file.
        sr (int, optional): Sampling rate for resampling. When given, resampling is always
            requested; otherwise `self.sample_rate` and `self.resample` are used. Defaults to None.

    Returns:
        np.ndarray: Loaded waveform.
    """
    if sr is None:
        x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
    else:
        x = load_wav(filename=filename, sample_rate=sr, resample=True)

    if self.do_trim_silence:
        try:
            x = self.trim_silence(x)
        except ValueError:
            # best effort: keep the untrimmed signal and report the file
            print(f" [!] File cannot be trimmed for silence - {filename}")

    if self.do_sound_norm:
        x = self.sound_norm(x)
    if self.do_rms_norm:
        x = self.rms_volume_norm(x, self.db_level)
    return x
|
||||
|
||||
def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
    """Save a waveform to a file using Scipy as 16-bit PCM.

    The waveform is scaled to the int16 range: with `self.do_rms_norm` it is RMS-normalized and
    multiplied by 32767, otherwise it is peak-normalized (the peak is floored at 0.01 to avoid
    blowing up near-silent signals). Values that still fall outside the int16 range are clipped
    rather than allowed to wrap around in the integer cast.

    Args:
        wav (np.ndarray): Waveform to save.
        path (str): Path to a output file.
        sr (int, optional): Sampling rate used for saving to the file. Falls back to
            `self.sample_rate` when not given. Defaults to None.
        pipe_out (optional): Object exposing a binary `.buffer` stream. When given, the wav
            bytes are also written there (for shell pipes). The file at `path` is written in any case.
    """
    if self.do_rms_norm:
        wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
    else:
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))

    # Saturate before casting: float samples beyond the int16 range would otherwise overflow.
    int16_range = np.iinfo(np.int16)
    wav_norm = np.clip(wav_norm, int16_range.min, int16_range.max).astype(np.int16)

    sample_rate = sr if sr else self.sample_rate
    if pipe_out:
        wav_buffer = BytesIO()
        scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm)
        wav_buffer.seek(0)
        pipe_out.buffer.write(wav_buffer.read())
    scipy.io.wavfile.write(path, sample_rate, wav_norm)
|
||||
|
||||
def get_duration(self, filename: str) -> float:
    """Get the duration of a wav file using Librosa.

    Args:
        filename (str): Path to the wav file.

    Returns:
        float: Value returned by `librosa.get_duration` for the file.
    """
    duration = librosa.get_duration(filename=filename)
    return duration
|
||||
@@ -0,0 +1,165 @@
|
||||
import librosa
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    Args:

        n_fft (int):
            FFT window size for STFT.

        hop_length (int):
            hop between adjacent STFT columns, forwarded to `torch.stft`.

        win_length (int, optional):
            STFT window length.

        pad_wav (bool, optional):
            If True reflect-pad the audio with `(n_fft - hop_length) / 2` samples on both sides
            before the STFT. Defaults to False.

        window (str, optional):
            The name of a `torch` function to create a window tensor that is applied/multiplied to each
            frame/window. Defaults to "hann_window"

        sample_rate (int, optional):
            target audio sampling rate. Defaults to None.

        mel_fmin (int, optional):
            minimum filter frequency for computing melspectrograms. Defaults to 0.

        mel_fmax (int, optional):
            maximum filter frequency for computing melspectrograms. Defaults to None.

        n_mels (int, optional):
            number of melspectrogram dimensions. Defaults to 80.

        use_mel (bool, optional):
            If True compute melspectrograms, otherwise linear spectrograms. Defaults to False.

        do_amp_to_db (bool, optional):
            enable/disable the log conversion (`_amp_to_db`) of the output. Defaults to False.

        spec_gain (float, optional):
            gain applied when converting amplitude to DB. Defaults to 1.0.

        power (float, optional):
            Exponent applied to the magnitude spectrogram. None leaves the magnitude unchanged.
            Defaults to None.

        use_htk (bool, optional):
            Use HTK formula in mel filter instead of Slaney.

        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).

            If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).

            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".

        normalized (bool, optional):
            Forwarded to `torch.stft` as `normalized`. Defaults to False.
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
        normalized=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        # the window is looked up by name on the torch module and stored as a frozen parameter
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        self.normalized = normalized
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: [B x T] or [:math:`[B, 1, T]`]
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # complex tensor B x D x T; `return_complex=False` is deprecated in torch.stft
        o = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=self.normalized,
            onesided=True,
            return_complex=True,
        )
        M = o.real
        P = o.imag
        # clamp before the square root to keep the magnitude strictly positive
        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))

        if self.power is not None:
            S = S**self.power

        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        """Create the mel filter bank with librosa and store it as a float tensor in `self.mel_basis`."""
        mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    @staticmethod
    def _amp_to_db(x, spec_gain=1.0):
        """Return the natural log of `x * spec_gain`, with `x` clamped to at least 1e-5."""
        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

    @staticmethod
    def _db_to_amp(x, spec_gain=1.0):
        """Return `exp(x) / spec_gain`."""
        return torch.exp(x) / spec_gain
|
||||
@@ -0,0 +1,105 @@
|
||||
class TrainerCallback:
    """Forward trainer lifecycle events to the model, criterion and optimizer.

    Each public hook calls the method of the same name, when it exists, on the trainer's model,
    then on `trainer.criterion`, then on `trainer.optimizer`, passing the trainer as the only argument.
    """

    @staticmethod
    def _dispatch(trainer, hook_name: str) -> None:
        """Call `hook_name(trainer)` on every trainer component that defines it."""
        model = trainer.model
        # When the model has a `module` attribute (presumably a DataParallel/DDP-style wrapper),
        # the hook is looked up on that inner object only, never on the wrapper itself.
        receiver = model.module if hasattr(model, "module") else model
        if hasattr(receiver, hook_name):
            getattr(receiver, hook_name)(trainer)

        # Components are read from the trainer one at a time, after the previous hook has run.
        for component in ("criterion", "optimizer"):
            receiver = getattr(trainer, component)
            if hasattr(receiver, hook_name):
                getattr(receiver, hook_name)(trainer)

    @staticmethod
    def on_init_start(trainer) -> None:
        """Dispatch the `on_init_start` hook."""
        TrainerCallback._dispatch(trainer, "on_init_start")

    @staticmethod
    def on_init_end(trainer) -> None:
        """Dispatch the `on_init_end` hook."""
        TrainerCallback._dispatch(trainer, "on_init_end")

    @staticmethod
    def on_epoch_start(trainer) -> None:
        """Dispatch the `on_epoch_start` hook."""
        TrainerCallback._dispatch(trainer, "on_epoch_start")

    @staticmethod
    def on_epoch_end(trainer) -> None:
        """Dispatch the `on_epoch_end` hook."""
        TrainerCallback._dispatch(trainer, "on_epoch_end")

    @staticmethod
    def on_train_step_start(trainer) -> None:
        """Dispatch the `on_train_step_start` hook."""
        TrainerCallback._dispatch(trainer, "on_train_step_start")

    @staticmethod
    def on_train_step_end(trainer) -> None:
        """Dispatch the `on_train_step_end` hook."""
        TrainerCallback._dispatch(trainer, "on_train_step_end")

    @staticmethod
    def on_keyboard_interrupt(trainer) -> None:
        """Dispatch the `on_keyboard_interrupt` hook."""
        TrainerCallback._dispatch(trainer, "on_keyboard_interrupt")
|
||||
@@ -0,0 +1,67 @@
|
||||
from typing import Generator
|
||||
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
|
||||
|
||||
class CapacitronOptimizer:
    """Double optimizer class for the Capacitron model.

    The parameter named ``capacitron_vae_layer.beta`` is handled by a secondary optimizer; all
    other trainable parameters are handled by the primary one.
    """

    def __init__(self, config: dict, model_params: Generator) -> None:
        """Create both optimizers.

        Args:
            config: Config exposing `optimizer_params` (a mapping whose first entry describes the
                primary optimizer and whose second entry describes the secondary one) and `lr`.
            model_params: Iterable of `(name, parameter)` pairs.
        """
        self.primary_params, self.secondary_params = self.split_model_parameters(model_params)

        optimizer_specs = list(config.optimizer_params.items())
        primary_name, primary_kwargs = optimizer_specs[0]
        secondary_name, secondary_kwargs = optimizer_specs[1]

        self.primary_optimizer = get_optimizer(
            primary_name,
            primary_kwargs,
            config.lr,
            parameters=self.primary_params,
        )

        # the secondary optimizer carries its own learning rate inside its parameter dictionary
        self.secondary_optimizer = get_optimizer(
            secondary_name,
            self.extract_optimizer_parameters(secondary_kwargs),
            secondary_kwargs["lr"],
            parameters=self.secondary_params,
        )

        self.param_groups = self.primary_optimizer.param_groups

    def first_step(self):
        """Step the secondary optimizer, then clear the gradients of both optimizers."""
        secondary = self.secondary_optimizer
        secondary.step()
        secondary.zero_grad()
        self.primary_optimizer.zero_grad()

    def step(self):
        """Step the primary optimizer."""
        primary = self.primary_optimizer
        # Update param groups to display the correct learning rate
        self.param_groups = primary.param_groups
        primary.step()

    def zero_grad(self, set_to_none=False):
        """Clear the gradients of both optimizers."""
        for optimizer in (self.primary_optimizer, self.secondary_optimizer):
            optimizer.zero_grad(set_to_none)

    def load_state_dict(self, state_dict):
        """Restore both optimizers from a `[primary_state, secondary_state]` sequence."""
        primary_state, secondary_state = state_dict[0], state_dict[1]
        self.primary_optimizer.load_state_dict(primary_state)
        self.secondary_optimizer.load_state_dict(secondary_state)

    def state_dict(self):
        """Return `[primary_state, secondary_state]`."""
        return [optimizer.state_dict() for optimizer in (self.primary_optimizer, self.secondary_optimizer)]

    @staticmethod
    def split_model_parameters(model_params: Generator) -> list:
        """Split trainable parameters into primary and secondary groups.

        Returns:
            list: Two iterators, `[primary, secondary]`. Parameters with `requires_grad` false are dropped.
        """
        trainable = [(name, param) for name, param in model_params if param.requires_grad]
        beta_name = "capacitron_vae_layer.beta"
        primary = [param for name, param in trainable if name != beta_name]
        secondary = [param for name, param in trainable if name == beta_name]
        return [iter(primary), iter(secondary)]

    @staticmethod
    def extract_optimizer_parameters(params: dict) -> dict:
        """Extract parameters that are not the learning rate"""
        remaining = {}
        for key, value in params.items():
            if key != "lr":
                remaining[key] = value
        return remaining
|
||||
@@ -0,0 +1,20 @@
|
||||
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def reduce_tensor(tensor, num_gpus):
    """Average a tensor across the processes of the default distributed group.

    Args:
        tensor: Tensor to reduce. It is cloned first, so the input is left untouched.
        num_gpus: Divisor applied to the summed tensor (the number of participating processes).

    Returns:
        The sum of `tensor` over all processes divided by `num_gpus`.
    """
    rt = tensor.clone()
    # `dist.reduce_op` is a deprecated alias; `dist.ReduceOp` is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt
|
||||
|
||||
|
||||
def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
    """Select the CUDA device for this process and initialize the distributed process group.

    Args:
        rank: Rank of the current process.
        num_gpus: World size of the process group.
        group_name: Forwarded to `init_process_group` as `group_name`.
        dist_backend: Backend name passed to `init_process_group`.
        dist_url: Passed to `init_process_group` as `init_method`.

    Raises:
        AssertionError: CUDA is not available.
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."

    # Set cuda device so everything is done on the right GPU.
    device_index = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)

    # Initialize distributed communication
    group_options = {
        "init_method": dist_url,
        "world_size": num_gpus,
        "rank": rank,
        "group_name": group_name,
    }
    dist.init_process_group(dist_backend, **group_options)
|
||||
@@ -0,0 +1,206 @@
|
||||
# Adapted from https://github.com/pytorch/audio/
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import tarfile
|
||||
import urllib
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from os.path import expanduser
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from torch.utils.model_zoo import tqdm
|
||||
|
||||
|
||||
def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
    """Stream url by chunk

    A HEAD request reads the `Content-Length` first. When it equals `start_byte` nothing is
    yielded. Otherwise the body is requested (with a `Range` header when `start_byte` is truthy)
    and yielded in blocks.

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
    """

    # If we already have the whole file, there is no need to download it again
    head_request = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(head_request) as head_response:
        url_size = int(head_response.info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    body_request = urllib.request.Request(url)
    if start_byte:
        body_request.headers["Range"] = f"bytes={start_byte}-"

    with urllib.request.urlopen(body_request) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        total=url_size,
        disable=not progress_bar,
    ) as pbar:
        # an empty read marks the end of the stream
        while chunk := upointer.read(block_size):
            yield chunk
            pbar.update(len(chunk))
|
||||
|
||||
|
||||
def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
            (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).

    Raises:
        RuntimeError: The target file already exists and `resume` is False, or the hash of the
            downloaded (or already complete) file does not match `hash_value`.
    """

    req = urllib.request.Request(url, method="HEAD")
    # Close the HEAD response once the headers are read; they stay usable afterwards.
    with urllib.request.urlopen(req) as head_response:
        req_info = head_response.info()

    # Detect filename
    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)
    if resume and os.path.exists(filepath):
        mode = "ab"
        local_size: Optional[int] = os.path.getsize(filepath)

    elif not resume and os.path.exists(filepath):
        raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
    else:
        mode = "wb"
        local_size = None

    # A resumed file that already has the full size only needs its hash verified.
    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    with open(filepath, "rb") as file_obj:
        if hash_value and not validate_file(file_obj, hash_value, hash_type):
            raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
|
||||
|
||||
|
||||
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from; it is consumed from its current position to the end.
        hash_value (str): Expected hex digest.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Raises:
        ValueError: `hash_type` is neither "sha256" nor "md5".

    Returns:
        bool: return True if its a valid file, else False.
    """
    hash_factories = {"sha256": hashlib.sha256, "md5": hashlib.md5}
    if hash_type not in hash_factories:
        raise ValueError
    digest = hash_factories[hash_type]()

    # Read by chunk to avoid filling memory
    chunk_size = 1024**2
    while chunk := file_obj.read(chunk_size):
        digest.update(chunk)

    return digest.hexdigest() == hash_value
|
||||
|
||||
|
||||
def _is_within_directory(directory: str, target: str) -> bool:
    """Return True when ``target`` resolves to a location inside ``directory``."""
    root = os.path.realpath(directory)
    try:
        return os.path.commonpath([root, os.path.realpath(target)]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        return False


def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract archive.

    The file is first tried as a tar archive (any compression ``tarfile`` detects)
    and, if that fails, as a zip archive.

    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
            (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        list: List of paths to extracted files even if not overwritten. For tar
        archives only regular files are listed; for zip archives every name in
        the archive is returned as stored.

    Raises:
        RuntimeError: If a tar member would be written outside ``to_path``.
        NotImplementedError: If the file is neither a readable tar nor a zip archive.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    try:
        with tarfile.open(from_path, "r") as tar:
            logging.info("Opened tar file %s.", from_path)
            files = []
            for file_ in tar:  # type: Any
                file_path = os.path.join(to_path, file_.name)
                # Archives are usually downloaded; refuse members such as "../x"
                # or absolute names that would escape the destination folder.
                if not _is_within_directory(to_path, file_path):
                    raise RuntimeError(
                        "Refusing to extract {} outside of {}.".format(file_.name, to_path)
                    )
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("%s already extracted.", file_path)
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        # Not a tar archive; fall through and try zip.
        pass

    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logging.info("Opened zip file %s.", from_path)
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logging.info("%s already extracted.", file_path)
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
        return files
    except zipfile.BadZipFile:
        # Not a zip archive either; report the unsupported format below.
        pass

    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
|
||||
|
||||
|
||||
def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
    """Download dataset from kaggle.

    Args:
        dataset_path (str):
            The kaggle link to the dataset, for example vctk is
            'mfekadu/english-multispeaker-corpus-for-voice-cloning'.
        dataset_name (str): Name of the folder the dataset will be saved in.
        output_path (str): Path of the location you want the dataset folder to be saved to.
    """
    target_dir = os.path.join(output_path, dataset_name)
    try:
        import kaggle  # pylint: disable=import-outside-toplevel

        kaggle.api.authenticate()
        print(f"\nDownloading {dataset_name}...")
        kaggle.api.dataset_download_files(dataset_path, path=target_dir, unzip=True)
    except OSError:
        # Only a hint is printed here; the error is not re-raised.
        token_path = os.path.join(expanduser("~"), ".kaggle/kaggle.json")
        print(
            f"[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {token_path}"
        )
|
||||
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
|
||||
|
||||
|
||||
def download_ljspeech(path: str):
    """Download and extract the LJSpeech dataset.

    Args:
        path (str): path to the directory where the dataset will be stored.
    """
    url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
    os.makedirs(path, exist_ok=True)
    download_url(url, path)
    # The archive is expected under `path`, named after the last URL component.
    archive = os.path.join(path, os.path.basename(url))
    print(" > Extracting archive file...")
    extract_archive(archive)
|
||||
|
||||
|
||||
def download_vctk(path: str, use_kaggle: Optional[bool] = False):
    """Download and extract the VCTK dataset.

    Args:
        path (str): path to the directory where the dataset will be stored.

        use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False.
    """
    if use_kaggle:
        download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path)
        return

    url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
    os.makedirs(path, exist_ok=True)
    download_url(url, path)
    archive = os.path.join(path, os.path.basename(url))
    print(" > Extracting archive file...")
    extract_archive(archive)
|
||||
|
||||
|
||||
def download_tweb(path: str):
    """Download and extract the Tweb dataset from kaggle.

    Args:
        path (str): Path to the directory where the dataset will be stored.
    """
    kaggle_slug = "bryanpark/the-world-english-bible-speech-dataset"
    download_kaggle_dataset(kaggle_slug, "TWEB", path)
|
||||
|
||||
|
||||
def download_libri_tts(path: str, subset: Optional[str] = "all"):
    """Download and extract the libri tts dataset.

    Args:
        path (str): Path to the directory where the dataset will be stored.

        subset (str, optional): Name of the subset to download. If you only want to download a certain
        portion specify it here. Defaults to 'all'.

    Raises:
        KeyError: If ``subset`` is neither ``"all"`` nor one of the known subset names.
    """
    subset_dict = {
        "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz",
        "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz",
        "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz",
        "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz",
        "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz",
        "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz",
        "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz",
    }

    os.makedirs(path, exist_ok=True)

    if subset != "all":
        # Single subset: the lookup raises KeyError for an unknown name.
        link = subset_dict[subset]
        download_url(link, path)
        archive = os.path.join(path, os.path.basename(link))
        print(" > Extracting archive file...")
        extract_archive(archive)
        return

    for name, link in subset_dict.items():
        print(f" > Downloading {name}...")
        download_url(link, path)
        archive = os.path.join(path, os.path.basename(link))
        print(" > Extracting archive file...")
        extract_archive(archive)
    print(" > All subsets downloaded")
|
||||
|
||||
|
||||
def download_thorsten_de(path: str):
    """Download and extract the Thorsten german male voice dataset.

    Args:
        path (str): Path to the directory where the dataset will be stored.
    """
    url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz"
    os.makedirs(path, exist_ok=True)
    download_url(url, path)
    archive = os.path.join(path, os.path.basename(url))
    print(" > Extracting archive file...")
    extract_archive(archive)
|
||||
|
||||
|
||||
def download_mailabs(path: str, language: str = "english"):
    """Download and extract the Mailabs dataset.

    Args:
        path (str): Path to the directory where the dataset will be stored.

        language (str): Language subset to download. Defaults to english.

    Raises:
        KeyError: If ``language`` is not one of the supported language names.
    """
    language_dict = {
        "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz",
        "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz",
        "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz",
        "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz",
        "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz",
    }
    os.makedirs(path, exist_ok=True)
    link = language_dict[language]
    download_url(link, path)
    archive = os.path.join(path, os.path.basename(link))
    print(" > Extracting archive file...")
    extract_archive(archive)
|
||||
@@ -0,0 +1,239 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
|
||||
|
||||
def to_cuda(x: torch.Tensor) -> torch.Tensor:
    """Make a tensor contiguous and move it to CUDA when CUDA is available.

    Args:
        x: A tensor, ``None``, or any other object.

    Returns:
        ``None`` and non-tensor inputs unchanged; for tensors, a contiguous
        tensor that is on the CUDA device if ``torch.cuda.is_available()``.
    """
    if x is None or not torch.is_tensor(x):
        return x
    x = x.contiguous()
    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
|
||||
|
||||
|
||||
def get_cuda():
    """Report CUDA availability and the matching default device.

    Returns:
        tuple: ``(use_cuda, device)`` where ``device`` is ``cuda:0`` when CUDA
        is available and ``cpu`` otherwise.
    """
    use_cuda = torch.cuda.is_available()
    device_name = "cuda:0" if use_cuda else "cpu"
    return use_cuda, torch.device(device_name)
|
||||
|
||||
|
||||
def get_git_branch():
    """Return the name of the currently checked-out git branch.

    Returns:
        str: The line of ``git branch`` marked with ``*``, without the leading
        ``"* "`` marker. ``"inside_docker"`` if the git command exits with an
        error, ``"unknown"`` if the git executable is missing or no line is
        marked as current.
    """
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
        current = next(line for line in out.split("\n") if line.startswith("*"))
        # str.replace returns a new string; keep it so the marker is really removed.
        current = current.replace("* ", "")
    except subprocess.CalledProcessError:
        current = "inside_docker"
    except (FileNotFoundError, StopIteration):
        current = "unknown"
    return current
|
||||
|
||||
|
||||
def get_commit_hash():
    """Return the short hash of the current git HEAD.

    See https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script

    Returns:
        str: Output of ``git rev-parse --short HEAD``, or ``"0000000"`` when the
        command fails or git is not installed (e.g. the .git folder was not
        copied into a docker container).
    """
    command = ["git", "rev-parse", "--short", "HEAD"]
    try:
        return subprocess.check_output(command).decode().strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "0000000"
|
||||
|
||||
|
||||
def get_experiment_folder_path(root_path, model_name):
    """Get an experiment folder path with the current date and time.

    Returns:
        str: ``root_path`` joined with ``<model_name>-<date>-<commit hash>``.
    """
    timestamp = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    folder_name = "-".join([model_name, timestamp, get_commit_hash()])
    return os.path.join(root_path, folder_name)
|
||||
|
||||
|
||||
def remove_experiment_folder(experiment_path):
    """Remove the experiment folder unless it contains a ``*.pth`` checkpoint."""
    fs = fsspec.get_mapper(experiment_path).fs
    if fs.glob(experiment_path + "/*.pth"):
        print(" ! Run is kept in {}".format(experiment_path))
        return
    # No checkpoint found: delete the folder if it exists at all.
    if fs.exists(experiment_path):
        fs.rm(experiment_path, recursive=True)
        print(" ! Run is removed from {}".format(experiment_path))
|
||||
|
||||
|
||||
def count_parameters(model):
    r"""Count number of trainable parameters in a network"""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
|
||||
|
||||
def to_camel(text):
    """Convert a snake_case name to CamelCase with ``TTS`` / ``VC`` upper-cased.

    The text is capitalized, every ``_x`` (not at the start) becomes ``X``, then
    ``"Tts"`` is replaced by ``"TTS"`` and ``"vc"`` by ``"VC"``.
    """

    def _upper_after_underscore(match):
        return match.group(1).upper()

    camel = re.sub(r"(?!^)_([a-zA-Z])", _upper_after_underscore, text.capitalize())
    return camel.replace("Tts", "TTS").replace("vc", "VC")
|
||||
|
||||
|
||||
def find_module(module_path: str, module_name: str) -> object:
    """Import ``module_path.<module_name lower-cased>`` and return its CamelCase-named attribute.

    Args:
        module_path (str): Dotted path of the package holding the module.
        module_name (str): Module name; lower-cased before import and converted
            with ``to_camel`` to obtain the attribute name.

    Returns:
        object: The attribute of the imported module named after the module.
    """
    lowered = module_name.lower()
    module = importlib.import_module(".".join([module_path, lowered]))
    return getattr(module, to_camel(lowered))
|
||||
|
||||
|
||||
def import_class(module_path: str) -> object:
    """Import a class from a module path.

    Args:
        module_path (str): Dotted path whose last component is the class name
            and whose prefix is the module to import.

    Returns:
        object: The imported class.
    """
    parent_module, _, class_name = module_path.rpartition(".")
    return getattr(importlib.import_module(parent_module), class_name)
|
||||
|
||||
|
||||
def get_import_path(obj: object) -> str:
    """Get the import path of an object's class.

    Args:
        obj (object): The object whose type is looked up.

    Returns:
        str: ``<module>.<class name>`` of ``type(obj)``.
    """
    cls = type(obj)
    return ".".join((cls.__module__, cls.__name__))
|
||||
|
||||
|
||||
def get_user_data_dir(appname):
    """Return the per-user data directory for ``appname``.

    The base directory is, in order of precedence: the ``TTS_HOME`` environment
    variable, the ``XDG_DATA_HOME`` environment variable, then a platform
    default (the "Local AppData" shell folder on Windows,
    ``~/Library/Application Support/`` on macOS, ``~/.local/share`` elsewhere).

    Args:
        appname (str): Folder name appended to the base directory.

    Returns:
        pathlib.Path: Base directory joined with ``appname``.
    """
    for env_name in ("TTS_HOME", "XDG_DATA_HOME"):
        override = os.environ.get(env_name)
        if override is not None:
            return Path(override).expanduser().resolve(strict=False).joinpath(appname)

    if sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        key = winreg.OpenKey(
            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
        )
        dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        base = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        base = Path("~/Library/Application Support/").expanduser()
    else:
        base = Path.home().joinpath(".local/share")
    return base.joinpath(appname)
|
||||
|
||||
|
||||
def set_init_dict(model_dict, checkpoint_state, c):
    """Partially initialize ``model_dict`` from ``checkpoint_state``.

    Checkpoint entries are copied into ``model_dict`` only when the key exists
    in the model, the element counts match, and the key does not contain any
    name listed in ``c.reinit_layers``.

    Args:
        model_dict (dict): State dict of the model; updated in place.
        checkpoint_state (dict): State dict loaded from a checkpoint.
        c: Config object exposing ``has(name)`` and, optionally, ``reinit_layers``.

    Returns:
        dict: The updated ``model_dict``.
    """
    # Report checkpoint layers the model no longer defines.
    for name in checkpoint_state:
        if name not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(name))

    # Keep only layers known to the model whose size matches.
    pretrained_dict = {
        name: tensor
        for name, tensor in checkpoint_state.items()
        if name in model_dict and tensor.numel() == model_dict[name].numel()
    }

    # Drop layers that were requested to be re-initialized.
    if c.has("reinit_layers") and c.reinit_layers is not None:
        skipped = tuple(c.reinit_layers)
        pretrained_dict = {
            name: tensor
            for name, tensor in pretrained_dict.items()
            if not any(fragment in name for fragment in skipped)
        }

    # Overwrite entries in the existing state dict.
    model_dict.update(pretrained_dict)
    print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
    return model_dict
|
||||
|
||||
|
||||
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Format kwargs to handle auxiliary inputs to models.

    Args:
        def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`.
        kwargs (Dict): A `dict` or `kwargs` that includes auxiliary inputs to the model.

    Returns:
        Dict: A copy of `kwargs` in which every name from `def_args` that is
        missing or ``None`` is set to its default.
    """
    merged = kwargs.copy()
    for name, default in def_args.items():
        if merged.get(name) is None:
            merged[name] = default
    return merged
|
||||
|
||||
|
||||
class KeepAverage:
    """Track running averages of named scalar values."""

    def __init__(self):
        self.avg_values = {}  # name -> current average
        self.iters = {}  # name -> number of values folded into the average

    def __getitem__(self, key):
        """Return the current average stored under ``key``."""
        return self.avg_values[key]

    def items(self):
        """Return ``(name, average)`` pairs of all tracked values."""
        return self.avg_values.items()

    def add_value(self, name, init_val=0, init_iter=0):
        """Register (or reset) ``name`` with an initial average and iteration count."""
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        """Fold ``value`` into the average stored under ``name``.

        Args:
            name: Key of the tracked value; created on first use.
            value: New observation.
            weighted_avg (bool): If True use an exponential moving average
                (0.99 / 0.01); otherwise keep the arithmetic mean.
        """
        if name not in self.avg_values:
            # The first observation is a real sample, so it counts as one
            # iteration; with a count of 0 the next mean update would discard it.
            self.add_value(name, init_val=value, init_iter=1)
            return

        if weighted_avg:
            self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
            self.iters[name] += 1
        else:
            # Incremental mean: undo the previous division, add, divide again.
            self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
            self.iters[name] += 1
            self.avg_values[name] /= self.iters[name]

    def add_values(self, name_dict):
        """Register every ``name -> initial value`` pair of ``name_dict``."""
        for key, value in name_dict.items():
            self.add_value(key, init_val=value)

    def update_values(self, value_dict):
        """Update every ``name -> value`` pair of ``value_dict`` (arithmetic mean)."""
        for key, value in value_dict.items():
            self.update_value(key, value)
|
||||
|
||||
|
||||
def get_timestamp():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    # `datetime` is imported as a module in this file, so the class must be
    # reached through it; `datetime.now()` would raise AttributeError.
    return datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
|
||||
|
||||
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
    """Configure the logger ``logger_name`` with optional file and console handlers.

    Args:
        logger_name (str): Name passed to ``logging.getLogger``.
        root (str): Directory for the log file (used only when ``tofile`` is True).
        phase (str): Prefix of the log file name, followed by ``_<timestamp>.log``.
        level: Logging level set on the logger.
        screen (bool): Attach a ``StreamHandler``.
        tofile (bool): Attach a ``FileHandler`` opened in write mode.
    """
    lg = logging.getLogger(logger_name)
    formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
    lg.setLevel(level)

    handlers = []
    if tofile:
        log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
        handlers.append(logging.FileHandler(log_file, mode="w"))
    if screen:
        handlers.append(logging.StreamHandler())

    # Handlers are added on every call; nothing already attached is removed.
    for handler in handlers:
        handler.setFormatter(formatter)
        lg.addHandler(handler)
|
||||
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import pickle as pickle_tts
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
||||
|
||||
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that maps the old ``mozilla_voice_tts`` module name to ``TTS``."""

    def find_class(self, module, name):
        # Rewrite the module path before the normal class lookup.
        renamed_module = module.replace("mozilla_voice_tts", "TTS")
        return super().find_class(renamed_module, name)
|
||||
|
||||
|
||||
class AttrDict(dict):
    """A dict whose keys are also accessible as instance attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Use the dict itself as the attribute namespace so that keys and
        # attributes share the same storage.
        self.__dict__ = self
|
||||
|
||||
|
||||
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Any path or url supported by fsspec.
        map_location: torch.device or str.
        cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    # Paths that exist on the local filesystem are never routed through the cache.
    is_local = os.path.isdir(path) or os.path.isfile(path)
    if cache and not is_local:
        opener = fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        )
    else:
        opener = fsspec.open(path, "rb")

    with opener as f:
        # NOTE(review): torch.load may unpickle arbitrary objects depending on the
        # torch version/kwargs - only load checkpoints from trusted sources.
        return torch.load(f, map_location=map_location, **kwargs)
|
||||
|
||||
|
||||
def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load a checkpoint from ``checkpoint_path`` into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives ``state["model"]``.
        checkpoint_path (str): Path or URL handed to ``load_fsspec``.
        use_cuda (bool): Call ``model.cuda()`` after loading.
        eval (bool): Call ``model.eval()`` after loading.
        cache (bool): Forwarded to ``load_fsspec``.

    Returns:
        tuple: ``(model, state)`` where ``state`` is the loaded checkpoint object.
    """
    cpu = torch.device("cpu")
    try:
        state = load_fsspec(checkpoint_path, map_location=cpu, cache=cache)
    except ModuleNotFoundError:
        # Retry with the renaming unpickler. Note that this assignment replaces
        # `Unpickler` on the imported pickle module itself.
        pickle_tts.Unpickler = RenamingUnpickler
        state = load_fsspec(checkpoint_path, map_location=cpu, pickle_module=pickle_tts, cache=cache)

    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state
|
||||
@@ -0,0 +1,621 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tarfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from shutil import copyfile, rmtree
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import fsspec
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config, read_json_with_comments
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
||||
# Maps lower-cased license names to a URL with the license text or summary.
LICENSE_URLS = {
    "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
    "mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mit": "https://choosealicense.com/licenses/mit/",
    "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/",
    "apache2": "https://choosealicense.com/licenses/apache-2.0/",
    "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
    "cpml": "https://coqui.ai/cpml.txt",
}
|
||||
|
||||
|
||||
class ModelManager(object):
|
||||
tqdm_progress = None
|
||||
"""Manage TTS models defined in .models.json.
|
||||
It provides an interface to list and download
|
||||
models defined in '.models.json'
|
||||
|
||||
Models are downloaded under '.TTS' folder in the user's
|
||||
home path.
|
||||
|
||||
Args:
|
||||
models_file (str): path to .model.json file. Defaults to None.
|
||||
output_prefix (str): prefix to `tts` to download models. Defaults to None
|
||||
progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
|
||||
verbose (bool): print info. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
    """Set up the manager and read the models file.

    Args:
        models_file (str): path to the models json file; when None the
            ``.models.json`` next to this module's parent folder is used.
        output_prefix (str): folder under which a ``tts`` sub-folder holds the
            downloads; when None ``get_user_data_dir("tts")`` is used.
        progress_bar (bool): stored on the instance for the download helpers.
        verbose (bool): stored on the instance; controls listing output.
    """
    super().__init__()
    self.progress_bar = progress_bar
    self.verbose = verbose
    self.output_prefix = (
        get_user_data_dir("tts") if output_prefix is None else os.path.join(output_prefix, "tts")
    )
    self.models_dict = None
    if models_file is None:
        # try the default location
        models_file = Path(__file__).parent / "../.models.json"
    self.read_models_file(models_file)
|
||||
|
||||
def read_models_file(self, file_path):
    """Read .models.json as a dict and store it in ``self.models_dict``.

    Args:
        file_path (str): path to .models.json.
    """
    parsed_models = read_json_with_comments(file_path)
    self.models_dict = parsed_models
|
||||
|
||||
def _list_models(self, model_type, model_count=0):
|
||||
if self.verbose:
|
||||
print("\n Name format: type/language/dataset/model")
|
||||
model_list = []
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
for model in self.models_dict[model_type][lang][dataset]:
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if self.verbose:
|
||||
if os.path.exists(output_path):
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
|
||||
else:
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
|
||||
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
|
||||
model_count += 1
|
||||
return model_list
|
||||
|
||||
def _list_for_model_type(self, model_type):
|
||||
models_name_list = []
|
||||
model_count = 1
|
||||
models_name_list.extend(self._list_models(model_type, model_count))
|
||||
return models_name_list
|
||||
|
||||
def list_models(self):
    """Return the names of all models of every model type.

    Numbering passed to ``_list_models`` starts at 1 for each model type.
    """
    first_index = 1
    return [
        name
        for model_type in self.models_dict
        for name in self._list_models(model_type, first_index)
    ]
|
||||
|
||||
def model_info_by_idx(self, model_query):
    """Print the description of the model from .models.json file using model_idx

    Args:
        model_query (str): <model_type>/<model_idx>, where the index is
            1-based within the given model type.
    """
    model_type, model_query_idx = model_query.split("/")
    try:
        model_query_idx = int(model_query_idx)
    except ValueError:
        # Only a failed int conversion is expected here; anything else
        # (e.g. KeyboardInterrupt) must propagate.
        print("> model_query_idx should be an integer!")
        return
    if model_query_idx <= 0:
        print("> model_query_idx should be a positive integer!")
        return

    if model_type not in self.models_dict:
        print(f"> model_type {model_type} does not exist in the list.")
        return

    # Flatten the nested dict in listing order so that the index lines up
    # with the numbering of the model listing.
    entries = [
        (lang, dataset, model)
        for lang, datasets in self.models_dict[model_type].items()
        for dataset, models in datasets.items()
        for model in models
    ]
    model_count = len(entries)
    if model_query_idx > model_count:
        print(f"model query idx exceeds the number of available models [{model_count}] ")
        return

    lang, dataset, model = entries[model_query_idx - 1]
    model_item = self.models_dict[model_type][lang][dataset][model]
    print(f"> model type : {model_type}")
    print(f"> language supported : {lang}")
    print(f"> dataset used : {dataset}")
    print(f"> model name : {model}")
    if "description" in model_item:
        print(f"> description : {model_item['description']}")
    else:
        print("> description : coming soon")
    if "default_vocoder" in model_item:
        print(f"> default_vocoder : {model_item['default_vocoder']}")
|
||||
|
||||
def model_info_by_full_name(self, model_query_name):
    """Print the description of the model from .models.json file using model_full_name

    Args:
        model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
    """
    model_type, lang, dataset, model = model_query_name.split("/")

    # Walk down the nested dict, reporting the first level that is missing.
    if model_type not in self.models_dict:
        print(f"> model_type {model_type} does not exist in the list.")
        return
    languages = self.models_dict[model_type]
    if lang not in languages:
        print(f"> lang {lang} does not exist for {model_type}.")
        return
    datasets = languages[lang]
    if dataset not in datasets:
        print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
        return
    models = datasets[dataset]
    if model not in models:
        print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
        return

    model_item = models[model]
    print(f"> model type : {model_type}")
    print(f"> language supported : {lang}")
    print(f"> dataset used : {dataset}")
    print(f"> model name : {model}")
    if "description" in model_item:
        print(f"> description : {model_item['description']}")
    else:
        print("> description : coming soon")
    if "default_vocoder" in model_item:
        print(f"> default_vocoder : {model_item['default_vocoder']}")
|
||||
|
||||
def list_tts_models(self):
    """Print all `TTS` models and return a list of model names

    Names are produced by `_list_for_model_type("tts_models")`.
    """
    model_type = "tts_models"
    return self._list_for_model_type(model_type)
|
||||
|
||||
def list_vocoder_models(self):
    """Print all the `vocoder` models and return a list of model names

    Names are produced by `_list_for_model_type("vocoder_models")`.
    """
    model_type = "vocoder_models"
    return self._list_for_model_type(model_type)
|
||||
|
||||
def list_vc_models(self):
    """Print all the voice conversion models and return a list of model names

    Names are produced by `_list_for_model_type("voice_conversion_models")`.
    """
    model_type = "voice_conversion_models"
    return self._list_for_model_type(model_type)
|
||||
|
||||
def list_langs(self):
    """Print all the available languages as ``type/language``"""
    print(" Name format: type/language")
    for model_type, languages in self.models_dict.items():
        for lang in languages:
            print(f" >: {model_type}/{lang} ")
|
||||
|
||||
def list_datasets(self):
    """Print all the datasets as ``type/language/dataset``"""
    print(" Name format: type/language/dataset")
    for model_type, languages in self.models_dict.items():
        for lang, datasets in languages.items():
            for dataset in datasets:
                print(f" >: {model_type}/{lang}/{dataset}")
|
||||
|
||||
@staticmethod
|
||||
def print_model_license(model_item: Dict):
|
||||
"""Print the license of a model
|
||||
|
||||
Args:
|
||||
model_item (dict): model item in the models.json
|
||||
"""
|
||||
if "license" in model_item and model_item["license"].strip() != "":
|
||||
print(f" > Model's license - {model_item['license']}")
|
||||
if model_item["license"].lower() in LICENSE_URLS:
|
||||
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
|
||||
else:
|
||||
print(" > Check https://opensource.org/licenses for more info.")
|
||||
else:
|
||||
print(" > Model's license - No license information available")
|
||||
|
||||
def _download_github_model(self, model_item: Dict, output_path: str):
|
||||
if isinstance(model_item["github_rls_url"], list):
|
||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
else:
|
||||
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
|
||||
def _download_hf_model(self, model_item: Dict, output_path: str):
|
||||
if isinstance(model_item["hf_url"], list):
|
||||
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
|
||||
else:
|
||||
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
|
||||
|
||||
def download_fairseq_model(self, model_name, output_path):
    """Download the fairseq model archive for the language in ``model_name``.

    Args:
        model_name (str): Name with four "/"-separated parts; the second part
            is the language code used to build the archive URL.
        output_path (str): Destination passed to ``_download_tar_file``.
    """
    URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
    _, lang, _, _ = model_name.split("/")
    # Build the URL with plain string formatting: os.path.join applies the
    # path rules of the local platform, which are not URL rules.
    model_download_uri = f"{URI_PREFIX}{lang}.tar.gz"
    self._download_tar_file(model_download_uri, output_path, self.progress_bar)
|
||||
|
||||
@staticmethod
|
||||
def set_model_url(model_item: Dict):
|
||||
model_item["model_url"] = None
|
||||
if "github_rls_url" in model_item:
|
||||
model_item["model_url"] = model_item["github_rls_url"]
|
||||
elif "hf_url" in model_item:
|
||||
model_item["model_url"] = model_item["hf_url"]
|
||||
elif "fairseq" in model_item["model_name"]:
|
||||
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
|
||||
elif "xtts" in model_item["model_name"]:
|
||||
model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
|
||||
return model_item
|
||||
|
||||
def _set_model_item(self, model_name):
|
||||
# fetch model info from the dict
|
||||
if "fairseq" in model_name:
|
||||
model_type = "tts_models"
|
||||
lang = model_name.split("/")[1]
|
||||
model_item = {
|
||||
"model_type": "tts_models",
|
||||
"license": "CC BY-NC 4.0",
|
||||
"default_vocoder": None,
|
||||
"author": "fairseq",
|
||||
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
||||
}
|
||||
model_item["model_name"] = model_name
|
||||
elif "xtts" in model_name and len(model_name.split("/")) != 4:
|
||||
# loading xtts models with only model name (e.g. xtts_v2.0.2)
|
||||
# check model name has the version number with regex
|
||||
version_regex = r"v\d+\.\d+\.\d+"
|
||||
if re.search(version_regex, model_name):
|
||||
model_version = model_name.split("_")[-1]
|
||||
else:
|
||||
model_version = "main"
|
||||
model_type = "tts_models"
|
||||
lang = "multilingual"
|
||||
dataset = "multi-dataset"
|
||||
model = model_name
|
||||
model_item = {
|
||||
"default_vocoder": None,
|
||||
"license": "CPML",
|
||||
"contact": "info@coqui.ai",
|
||||
"tos_required": True,
|
||||
"hf_url": [
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth",
|
||||
],
|
||||
}
|
||||
else:
|
||||
# get model from models.json
|
||||
model_type, lang, dataset, model = model_name.split("/")
|
||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||
model_item["model_type"] = model_type
|
||||
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
|
||||
model_item = self.set_model_url(model_item)
|
||||
return model_item, model_full_name, model, md5hash
|
||||
|
||||
@staticmethod
|
||||
def ask_tos(model_full_path):
    """Prompt on stdin for acceptance of the model terms and record a positive answer.

    Args:
        model_full_path (str): directory in which ``tos_agreed.txt`` is written on acceptance.

    Returns:
        bool: True when the answer is ``y`` (case-insensitive), False for anything else.
    """
    tos_path = os.path.join(model_full_path, "tos_agreed.txt")
    prompt_lines = (
        " > You must confirm the following:",
        ' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"',
        ' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]',
    )
    for prompt_line in prompt_lines:
        print(prompt_line)
    if input(" | | > ").lower() != "y":
        return False
    # Persist the agreement so later runs can skip the prompt.
    with open(tos_path, "w", encoding="utf-8") as marker:
        marker.write("I have read, understood and agreed to the Terms and Conditions.")
    return True
|
||||
|
||||
@staticmethod
|
||||
def tos_agreed(model_item, model_full_path):
    """Tell whether the terms of service are satisfied for a model.

    Args:
        model_item (dict): model description; only the ``tos_required`` key is read.
        model_full_path (str): model directory that may contain ``tos_agreed.txt``.

    Returns:
        bool: True when no agreement is required, or when either the marker file
        exists or the ``COQUI_TOS_AGREED`` environment variable equals ``"1"``.
    """
    if "tos_required" not in model_item or not model_item["tos_required"]:
        return True
    marker = os.path.join(model_full_path, "tos_agreed.txt")
    return os.path.exists(marker) or os.environ.get("COQUI_TOS_AGREED") == "1"
|
||||
|
||||
def create_dir_and_download_model(self, model_name, model_item, output_path):
    """Create the model directory, enforce the terms of service and download the files.

    Args:
        model_name (str): model name; a name containing ``fairseq`` selects the fairseq downloader.
        model_item (dict): model description; ``github_rls_url`` or ``hf_url`` selects the downloader.
        output_path (str): directory the model files are written to.

    Raises:
        Exception: if the terms of service are required and the user does not agree.
        requests.RequestException: re-raised after the output directory has been removed.
    """
    os.makedirs(output_path, exist_ok=True)
    # handle TOS
    if not self.tos_agreed(model_item, output_path):
        if not self.ask_tos(output_path):
            try:
                os.rmdir(output_path)
            except OSError:
                # The directory already holds files (e.g. a re-download of an existing
                # model). Keep it, so the refusal below is what the caller sees rather
                # than a "directory not empty" error.
                pass
            raise Exception(" [!] You must agree to the terms of service to use this model.")
    print(f" > Downloading model to {output_path}")
    try:
        if "fairseq" in model_name:
            self.download_fairseq_model(model_name, output_path)
        elif "github_rls_url" in model_item:
            self._download_github_model(model_item, output_path)
        elif "hf_url" in model_item:
            self._download_hf_model(model_item, output_path)
    except requests.RequestException:
        print(f" > Failed to download the model file to {output_path}")
        rmtree(output_path)
        raise
    self.print_model_license(model_item=model_item)
|
||||
|
||||
def check_if_configs_are_equal(self, model_name, model_item, output_path):
    """Re-download the model when the local ``config.json`` differs from the remote one.

    Args:
        model_name (str): model name, used in the log message and passed to the downloader.
        model_item (dict): model description; the first ``hf_url`` entry containing
            ``config.json`` is taken as the remote config.
        output_path (str): local model directory.
    """
    local_config_path = self._find_files(output_path)[1]
    with fsspec.open(local_config_path, "r", encoding="utf-8") as stream:
        config_local = json.load(stream)

    # Stays None when no URL mentions config.json; fsspec then receives None.
    remote_url = next((url for url in model_item["hf_url"] if "config.json" in url), None)
    with fsspec.open(remote_url, "r", encoding="utf-8") as stream:
        config_remote = json.load(stream)

    if config_local == config_remote:
        return
    print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
    self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
|
||||
def download_model(self, model_name):
    """Download model files given the full model name.

    Model name is in the format
        'type/language/dataset/model'
        e.g. 'tts_models/en/ljspeech/tacotron'

    Every model must have the following files:
        - *.pth : pytorch model checkpoint file.
        - config.json : model config file.
        - scale_stats.npy (if exist): scale values for preprocessing.

    Args:
        model_name (str): model name as explained above.

    Returns:
        tuple: ``(output_model_path, output_config_path, model_item)``. For tortoise-v2,
        bark, fairseq and xtts models the first element is the model directory and the
        second is None; otherwise they are the checkpoint and config file paths.
    """
    model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
    # set the model specific output path
    output_path = os.path.join(self.output_prefix, model_full_name)
    if os.path.exists(output_path):
        if md5sum is not None:
            md5sum_file = os.path.join(output_path, "hash.md5")
            if os.path.isfile(md5sum_file):
                with open(md5sum_file, mode="r", encoding="utf-8") as f:
                    if not f.read() == md5sum:
                        print(f" > {model_name} has been updated, clearing model cache...")
                        self.create_dir_and_download_model(model_name, model_item, output_path)
                    else:
                        print(f" > {model_name} is already downloaded.")
            else:
                print(f" > {model_name} has been updated, clearing model cache...")
                self.create_dir_and_download_model(model_name, model_item, output_path)
        # if the configs are different, redownload it
        # ToDo: we need a better way to handle it
        if "xtts" in model_name:
            try:
                self.check_if_configs_are_equal(model_name, model_item, output_path)
            except Exception as e:  # pylint: disable=broad-except
                # Best effort only: a failed comparison (offline, missing file, ...) must not
                # block loading the cached model. KeyboardInterrupt/SystemExit still propagate.
                print(f" > Could not compare the local and remote configs of {model_name}: {e}")
        else:
            print(f" > {model_name} is already downloaded.")
    else:
        self.create_dir_and_download_model(model_name, model_item, output_path)

    # find downloaded files
    output_model_path = output_path
    output_config_path = None
    if (
        model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
    ):  # TODO: replace the name-based special-casing with a per-model flag.
        output_model_path, output_config_path = self._find_files(output_path)
        # update paths in the config.json
        self._update_paths(output_path, output_config_path)
    return output_model_path, output_config_path, model_item
|
||||
|
||||
@staticmethod
|
||||
def _find_files(output_path: str) -> Tuple[str, str]:
|
||||
"""Find the model and config files in the output path
|
||||
|
||||
Args:
|
||||
output_path (str): path to the model files
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: path to the model file and config file
|
||||
"""
|
||||
model_file = None
|
||||
config_file = None
|
||||
for file_name in os.listdir(output_path):
|
||||
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
|
||||
model_file = os.path.join(output_path, file_name)
|
||||
elif file_name == "config.json":
|
||||
config_file = os.path.join(output_path, file_name)
|
||||
if model_file is None:
|
||||
raise ValueError(" [!] Model file not found in the output path")
|
||||
if config_file is None:
|
||||
raise ValueError(" [!] Config file not found in the output path")
|
||||
return model_file, config_file
|
||||
|
||||
@staticmethod
|
||||
def _find_speaker_encoder(output_path: str) -> str:
|
||||
"""Find the speaker encoder file in the output path
|
||||
|
||||
Args:
|
||||
output_path (str): path to the model files
|
||||
|
||||
Returns:
|
||||
str: path to the speaker encoder file
|
||||
"""
|
||||
speaker_encoder_file = None
|
||||
for file_name in os.listdir(output_path):
|
||||
if file_name in ["model_se.pth", "model_se.pth.tar"]:
|
||||
speaker_encoder_file = os.path.join(output_path, file_name)
|
||||
return speaker_encoder_file
|
||||
|
||||
def _update_paths(self, output_path: str, config_path: str) -> None:
|
||||
"""Update paths for certain files in config.json after download.
|
||||
|
||||
Args:
|
||||
output_path (str): local path the model is downloaded to.
|
||||
config_path (str): local config.json path.
|
||||
"""
|
||||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
||||
output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth")
|
||||
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
||||
output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth")
|
||||
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
|
||||
speaker_encoder_model_path = self._find_speaker_encoder(output_path)
|
||||
|
||||
# update the scale_path.npy file path in the model config.json
|
||||
self._update_path("audio.stats_path", output_stats_path, config_path)
|
||||
|
||||
# update the speakers.json file path in the model config.json to the current path
|
||||
self._update_path("d_vector_file", output_d_vector_file_path, config_path)
|
||||
self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path)
|
||||
self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
|
||||
self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path)
|
||||
|
||||
# update the speaker_ids.json file path in the model config.json to the current path
|
||||
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
|
||||
self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path)
|
||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
|
||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path)
|
||||
|
||||
# update the speaker_encoder file path in the model config.json to the current path
|
||||
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
|
||||
@staticmethod
|
||||
def _update_path(field_name, new_path, config_path):
|
||||
"""Update the path in the model config.json for the current environment after download"""
|
||||
if new_path and os.path.exists(new_path):
|
||||
config = load_config(config_path)
|
||||
field_names = field_name.split(".")
|
||||
if len(field_names) > 1:
|
||||
# field name points to a sub-level field
|
||||
sub_conf = config
|
||||
for fd in field_names[:-1]:
|
||||
if fd in sub_conf:
|
||||
sub_conf = sub_conf[fd]
|
||||
else:
|
||||
return
|
||||
if isinstance(sub_conf[field_names[-1]], list):
|
||||
sub_conf[field_names[-1]] = [new_path]
|
||||
else:
|
||||
sub_conf[field_names[-1]] = new_path
|
||||
else:
|
||||
# field name points to a top-level field
|
||||
if not field_name in config:
|
||||
return
|
||||
if isinstance(config[field_name], list):
|
||||
config[field_name] = [new_path]
|
||||
else:
|
||||
config[field_name] = new_path
|
||||
config.save_json(config_path)
|
||||
|
||||
@staticmethod
|
||||
def _download_zip_file(file_url, output_folder, progress_bar):
    """Download a zip archive and flatten its files into ``output_folder``.

    The archive is streamed to a temporary file inside ``output_folder``, extracted,
    and every extracted file is copied to the top level of ``output_folder``; the
    extracted sub-folders are then removed.

    Args:
        file_url (str): URL of the zip file; its last path segment names the temporary file.
        output_folder (str): existing directory receiving the files.
        progress_bar (bool): show a tqdm progress bar while downloading.

    Raises:
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    # download the file
    r = requests.get(file_url, stream=True)
    temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
    # extract the file
    try:
        total_size_in_bytes = int(r.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        if progress_bar:
            ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        with open(temp_zip_name, "wb") as file:
            for data in r.iter_content(block_size):
                if progress_bar:
                    ModelManager.tqdm_progress.update(len(data))
                file.write(data)
        # NOTE(review): extractall trusts the member paths of the downloaded archive.
        with zipfile.ZipFile(temp_zip_name) as z:
            z.extractall(output_folder)
            member_names = z.namelist()
    except zipfile.BadZipFile:
        print(f" > Error: Bad zip file - {file_url}")
        raise  # keep the original message and traceback
    finally:
        # delete the archive after extraction, and also when it turned out to be corrupt
        if os.path.exists(temp_zip_name):
            os.remove(temp_zip_name)
    # move the files to the outer path
    for file_path in member_names:
        src_path = os.path.join(output_folder, file_path)
        if os.path.isfile(src_path):
            dst_path = os.path.join(output_folder, os.path.basename(file_path))
            if src_path != dst_path:
                copyfile(src_path, dst_path)
    # remove redundant (hidden or not) folders
    for file_path in member_names:
        if os.path.isdir(os.path.join(output_folder, file_path)):
            rmtree(os.path.join(output_folder, file_path))
|
||||
|
||||
@staticmethod
|
||||
def _download_tar_file(file_url, output_folder, progress_bar):
    """Download a tar archive and flatten its top-level folder into ``output_folder``.

    The archive is streamed to a temporary file inside ``output_folder`` and extracted.
    The first name listed in the archive is treated as the extracted folder: its entries
    are copied to ``output_folder`` and the folder is removed.

    Args:
        file_url (str): URL of the tar file; its last path segment names the temporary file.
        output_folder (str): existing directory receiving the files.
        progress_bar (bool): show a tqdm progress bar while downloading.

    Raises:
        tarfile.ReadError: if the downloaded file cannot be read as a tar archive.
    """
    # download the file
    r = requests.get(file_url, stream=True)
    temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
    # extract the file
    try:
        total_size_in_bytes = int(r.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        if progress_bar:
            ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        with open(temp_tar_name, "wb") as file:
            for data in r.iter_content(block_size):
                if progress_bar:
                    ModelManager.tqdm_progress.update(len(data))
                file.write(data)
        # NOTE(review): extractall trusts the member paths of the downloaded archive.
        with tarfile.open(temp_tar_name) as t:
            t.extractall(output_folder)
            tar_names = t.getnames()
    except tarfile.ReadError:
        print(f" > Error: Bad tar file - {file_url}")
        raise  # keep the original message and traceback
    finally:
        # delete the archive after extraction, and also when it turned out to be corrupt
        if os.path.exists(temp_tar_name):
            os.remove(temp_tar_name)
    # move the files to the outer path
    extracted_root = os.path.join(output_folder, tar_names[0])
    for file_path in os.listdir(extracted_root):
        src_path = os.path.join(extracted_root, file_path)
        dst_path = os.path.join(output_folder, os.path.basename(file_path))
        if src_path != dst_path:
            copyfile(src_path, dst_path)
    # remove the extracted folder
    rmtree(extracted_root)
|
||||
|
||||
@staticmethod
|
||||
def _download_model_files(file_urls, output_folder, progress_bar):
    """Stream each URL to a file of the same base name inside ``output_folder``.

    Args:
        file_urls (list): URLs to fetch; the last path segment is used as the file name.
        output_folder (str): existing directory receiving the files.
        progress_bar (bool): show one tqdm progress bar per file.
    """
    chunk_size = 1024  # 1 Kibibyte
    for file_url in file_urls:
        response = requests.get(file_url, stream=True)
        destination = os.path.join(output_folder, file_url.split("/")[-1])
        expected_bytes = int(response.headers.get("content-length", 0))
        with open(destination, "wb") as sink:
            if progress_bar:
                ModelManager.tqdm_progress = tqdm(total=expected_bytes, unit="iB", unit_scale=True)
            for chunk in response.iter_content(chunk_size):
                if progress_bar:
                    ModelManager.tqdm_progress.update(len(chunk))
                sink.write(chunk)
|
||||
|
||||
@staticmethod
|
||||
def _check_dict_key(my_dict, key):
|
||||
if key in my_dict.keys() and my_dict[key] is not None:
|
||||
if not isinstance(key, str):
|
||||
return True
|
||||
if isinstance(key, str) and len(my_dict[key]) > 0:
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,105 @@
|
||||
# modified from https://github.com/LiyuanLucasLiu/RAdam
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
|
||||
class RAdam(Optimizer):
    """RAdam optimizer (modified from https://github.com/LiyuanLucasLiu/RAdam).

    Args:
        params: parameters or parameter-group dicts to optimize.
        lr (float): learning rate, must be >= 0.
        betas (tuple): coefficients of the running averages, each in [0, 1).
        eps (float): term added to the denominator, must be >= 0.
        weight_decay (float): decay factor applied to the parameters before each update.
        degenerated_to_sgd (bool): when the variance term is below the threshold, take a
            momentum-only step instead of skipping the update.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        for index in (0, 1):
            if not 0.0 <= betas[index] < 1.0:
                raise ValueError(f"Invalid beta parameter at index {index}: {betas[index]}")

        self.degenerated_to_sgd = degenerated_to_sgd
        groups_given = isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict)
        if groups_given:
            # a group with its own betas gets its own step-size cache
            for group in params:
                if "betas" in group and (group["betas"][0] != betas[0] or group["betas"][1] != betas[1]):
                    group["buffer"] = [[None, None, None] for _ in range(10)]
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "buffer": [[None, None, None] for _ in range(10)],
        }
        super().__init__(params, defaults)

    def __setstate__(self, state):  # pylint: disable=useless-super-delegation
        super().__setstate__(state)

    def _rectified_step_size(self, group, step):
        """Return ``(N_sma, step_size)`` for ``step``, cached in one of the group's ten slots."""
        slot = group["buffer"][int(step % 10)]
        if step == slot[0]:
            return slot[1], slot[2]

        beta1, beta2 = group["betas"]
        slot[0] = step
        beta2_t = beta2**step
        n_sma_max = 2 / (1 - beta2) - 1
        n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        slot[1] = n_sma

        # more conservative since it's an approximated value
        if n_sma >= 5:
            step_size = math.sqrt(
                (1 - beta2_t)
                * (n_sma - 4)
                / (n_sma_max - 4)
                * (n_sma - 2)
                / n_sma
                * n_sma_max
                / (n_sma_max - 2)
            ) / (1 - beta1**step)
        elif self.degenerated_to_sgd:
            step_size = 1.0 / (1 - beta1**step)
        else:
            step_size = -1
        slot[2] = step_size
        return n_sma, step_size

    def step(self, closure=None):
        """Perform one optimization step; returns ``closure()`` when a closure is given, else None."""
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # all arithmetic happens on a float32 copy that is written back at the end
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                beta1, beta2 = group["betas"]
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                n_sma, step_size = self._rectified_step_size(group, state["step"])

                # more conservative since it's an approximated value
                adaptive = n_sma >= 5
                if not (adaptive or step_size > 0):
                    continue  # no update for this parameter on this step

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
                if adaptive:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
                p.data.copy_(p_data_fp32)

        return loss
|
||||
@@ -0,0 +1,201 @@
|
||||
import math
|
||||
import random
|
||||
from typing import Callable, List, Union
|
||||
|
||||
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
|
||||
|
||||
|
||||
class SubsetSampler(Sampler):
    """
    Yields the given indices one by one, in the order they were passed.

    Args:
        indices (list): a sequence of indices
    """

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __iter__(self):
        count = len(self.indices)
        return (self.indices[position] for position in range(count))

    def __len__(self):
        return len(self.indices)
|
||||
|
||||
|
||||
class PerfectBatchSampler(Sampler):
    """
    Samples a mini-batch of indices for a balanced class batching

    Each round draws one index from every active class; rounds are accumulated until
    the batch holds ``batch_size`` indices. Iteration stops as soon as one active
    class has no index left.

    Args:
        dataset_items(list): dataset items to sample from.
        classes (list): list of classes of dataset_items to sample from.
        batch_size (int): total number of samples to be sampled in a mini-batch.
        num_classes_in_batch (int): number of classes drawn from in each batch. When it
            differs from ``len(classes)``, a random subset of classes is re-drawn after every batch.
        num_gpus (int): number of GPU in the data parallel mode.
        shuffle (bool): if True, samples randomly, otherwise samples sequentially.
        drop_last (bool): if True, drops last incomplete batch.
        label_key (str): key of each dataset item that holds its class label.
    """

    def __init__(
        self,
        dataset_items,
        classes,
        batch_size,
        num_classes_in_batch,
        num_gpus=1,
        shuffle=True,
        drop_last=False,
        label_key="class_name",
    ):
        super().__init__(dataset_items)
        assert (
            batch_size % (num_classes_in_batch * num_gpus) == 0
        ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."

        # group the item positions by their label
        label_indices = {}
        for position, item in enumerate(dataset_items):
            label_indices.setdefault(item[label_key], []).append(position)

        sampler_type = SubsetRandomSampler if shuffle else SubsetSampler
        self._samplers = [sampler_type(label_indices[key]) for key in classes]

        self._batch_size = batch_size
        self._drop_last = drop_last
        self._dp_devices = num_gpus
        self._num_classes_in_batch = num_classes_in_batch

    def _draw_active_classes(self):
        """Pick the positions of the samplers that contribute to the next batch."""
        return random.sample(range(len(self._samplers)), self._num_classes_in_batch)

    def __iter__(self):
        pending = []
        if self._num_classes_in_batch != len(self._samplers):
            active = self._draw_active_classes()
        else:
            active = None  # every class takes part in every batch

        streams = [iter(sampler) for sampler in self._samplers]

        while True:
            round_indices = []
            for sampler_pos, stream in enumerate(streams):
                if active is not None and sampler_pos not in active:
                    continue
                drawn = next(stream, None)
                if drawn is None:
                    break  # a class ran dry: discard this partial round
                round_indices.append(drawn)
            else:
                pending.extend(round_indices)
                if len(pending) == self._batch_size:
                    yield pending
                    pending = []
                    if active is not None:
                        active = self._draw_active_classes()
                continue
            break

        if self._drop_last or len(pending) == 0:
            return
        # keep only as many whole rounds as can be split evenly across the devices
        groups = len(pending) // self._num_classes_in_batch
        if groups % self._dp_devices != 0:
            pending = pending[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
        if len(pending) > 0:
            yield pending

    def __len__(self):
        class_batch_size = self._batch_size // self._num_classes_in_batch
        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
|
||||
|
||||
|
||||
def identity(x):
    """Return the argument unchanged; used as the default sort key of the samplers."""
    return x
|
||||
|
||||
|
||||
class SortedSampler(Sampler):
    """Yields the positions of ``data`` ordered by ``sort_key``, identically on every pass.

    Taken from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        data (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

    """

    def __init__(self, data, sort_key: Callable = identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # keys are computed once; the sort is stable, so equal keys keep their data order
        keys = [self.sort_key(row) for row in self.data]
        self.sorted_indexes = sorted(range(len(keys)), key=keys.__getitem__)

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)
|
||||
|
||||
|
||||
class BucketBatchSampler(BatchSampler):
    """Bucket batch sampler

    Indices from ``sampler`` are collected into buckets; inside each bucket the samples
    are sorted with ``sort_key``, cut into batches, and the batches are yielded in random order.

    Adapted from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        sampler (torch.data.utils.sampler.Sampler):
        batch_size (int): Size of mini-batch.
        drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
            than `batch_size`.
        data (list): List of data samples.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.

    Example:
        >>> sampler = WeightedRandomSampler(weights, len(weights))
        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
    """

    def __init__(
        self,
        sampler,
        data,
        batch_size,
        drop_last,
        sort_key: Union[Callable, List] = identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.data = data
        self.sort_key = sort_key
        bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            # never ask for a bucket larger than the sampler itself
            bucket_size = min(bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            bucket_items = [self.data[index] for index in bucket]
            by_key = SortedSampler(bucket_items, self.sort_key)
            batches = list(BatchSampler(by_key, self.batch_size, self.drop_last))
            for local_positions in SubsetRandomSampler(batches):
                # translate positions inside the bucket back to dataset indices
                yield [bucket[position] for position in local_positions]

    def __len__(self):
        total = len(self.sampler)
        return total // self.batch_size if self.drop_last else math.ceil(total / self.batch_size)
|
||||
@@ -0,0 +1,505 @@
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pysbd
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.tts.models.vits import Vits
|
||||
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import save_wav
|
||||
from TTS.vc.models import setup_model as setup_vc_model
|
||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
|
||||
|
||||
|
||||
class Synthesizer(nn.Module):
|
||||
def __init__(
    self,
    tts_checkpoint: str = "",
    tts_config_path: str = "",
    tts_speakers_file: str = "",
    tts_languages_file: str = "",
    vocoder_checkpoint: str = "",
    vocoder_config: str = "",
    encoder_checkpoint: str = "",
    encoder_config: str = "",
    vc_checkpoint: str = "",
    vc_config: str = "",
    model_dir: str = "",
    voice_dir: str = None,
    use_cuda: bool = False,
) -> None:
    """General 🐸 TTS interface for inference. It takes a tts and a vocoder
    model and synthesize speech from the provided text.

    The text is divided into a list of sentences using `pysbd` and synthesize
    speech on each sentence separately.

    If you have certain special characters in your text, you need to handle
    them before providing the text to Synthesizer.

    Models are loaded in the order tts, vocoder, voice conversion, model directory;
    each one that is loaded overwrites ``output_sample_rate``.

    TODO: set the segmenter based on the source language

    Args:
        tts_checkpoint (str, optional): path to the tts model file.
        tts_config_path (str, optional): path to the tts config file.
        tts_speakers_file (str, optional): stored as given on the instance.
        tts_languages_file (str, optional): stored as given on the instance.
        vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to `""`.
        vocoder_config (str, optional): path to the vocoder config file. Defaults to `""`.
        encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`,
        encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`,
        vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`,
        vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
        model_dir (str, optional): directory to load a model from; a path containing
            ``fairseq`` is loaded as a fairseq model. Defaults to `""`.
        voice_dir (str, optional): stored as given on the instance. Defaults to None.
        use_cuda (bool, optional): enable/disable cuda. Defaults to False.
    """
    super().__init__()

    # arguments, stored as given
    self.tts_checkpoint = tts_checkpoint
    self.tts_config_path = tts_config_path
    self.tts_speakers_file = tts_speakers_file
    self.tts_languages_file = tts_languages_file
    self.vocoder_checkpoint = vocoder_checkpoint
    self.vocoder_config = vocoder_config
    self.encoder_checkpoint = encoder_checkpoint
    self.encoder_config = encoder_config
    self.vc_checkpoint = vc_checkpoint
    self.vc_config = vc_config
    self.voice_dir = voice_dir
    self.use_cuda = use_cuda

    # state filled in by the loaders
    self.tts_model = None
    self.vocoder_model = None
    self.vc_model = None
    self.speaker_manager = None
    self.language_manager = None
    self.tts_speakers = {}
    self.tts_languages = {}
    self.num_languages = 0
    self.d_vector_dim = 0
    self.seg = self._get_segmenter("en")

    if self.use_cuda:
        assert torch.cuda.is_available(), "CUDA is not availabe on this machine."

    if tts_checkpoint:
        self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
        self.output_sample_rate = self.tts_config.audio["sample_rate"]

    if vocoder_checkpoint:
        self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
        self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

    if vc_checkpoint:
        self._load_vc(vc_checkpoint, vc_config, use_cuda)
        self.output_sample_rate = self.vc_config.audio["output_sample_rate"]

    if not model_dir:
        return
    if "fairseq" in model_dir:
        self._load_fairseq_from_dir(model_dir, use_cuda)
        self.output_sample_rate = self.tts_config.audio["sample_rate"]
    else:
        self._load_tts_from_dir(model_dir, use_cuda)
        self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
||||
|
||||
@staticmethod
|
||||
def _get_segmenter(lang: str):
    """Build the sentence segmenter for the given language.

    Args:
        lang (str): target language code.

    Returns:
        pysbd.Segmenter: segmenter created with ``clean=True`` for ``lang``.
    """
    segmenter = pysbd.Segmenter(language=lang, clean=True)
    return segmenter
|
||||
|
||||
def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None:
    """Load the voice conversion model.

    1. Load the model config.
    2. Init the model from the config.
    3. Load the model weights.
    4. Move the model to the GPU if CUDA is enabled.

    Args:
        vc_checkpoint (str): path to the model checkpoint.
        vc_config_path (str): path to the model config file.
        use_cuda (bool): enable/disable CUDA use.
    """
    vc_config = load_config(vc_config_path)
    self.vc_config = vc_config
    self.vc_model = setup_vc_model(config=vc_config)
    self.vc_model.load_checkpoint(vc_config, vc_checkpoint)
    if not use_cuda:
        return
    self.vc_model.cuda()
|
||||
|
||||
def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
    """Load a fairseq checkpoint directory as a VITS model.

    A default ``VitsConfig`` is only used to construct the model; after the checkpoint
    is loaded, ``tts_config`` is replaced by the config held by the model.

    Args:
        model_dir (str): directory passed to the model as ``checkpoint_dir``.
        use_cuda (bool): move the model to the GPU after loading.
    """
    bootstrap_config = VitsConfig()
    self.tts_config = bootstrap_config
    self.tts_model = Vits.init_from_config(bootstrap_config)
    self.tts_model.load_fairseq_checkpoint(bootstrap_config, checkpoint_dir=model_dir, eval=True)
    self.tts_config = self.tts_model.config
    if not use_cuda:
        return
    self.tts_model.cuda()
|
||||
|
||||
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
    """Load a TTS model from a directory that contains ``config.json``.

    The model is built from that config and is handed the directory itself as
    ``checkpoint_dir``, so locating the weight files is left to the model.

    Args:
        model_dir (str): directory holding ``config.json`` and the model files.
        use_cuda (bool): move the model to the GPU after loading.
    """
    config_file = os.path.join(model_dir, "config.json")
    self.tts_config = load_config(config_file)
    self.tts_model = setup_tts_model(self.tts_config)
    self.tts_model.load_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
    if not use_cuda:
        return
    self.tts_model.cuda()
|
||||
|
||||
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
    """Load the TTS model.

    1. Load the model config.
    2. Init the model from the config.
    3. Load the model weights.
    4. Move the model to the GPU if CUDA is enabled.
    5. Init the speaker encoder in the model's speaker manager, if both an
       encoder checkpoint and a speaker manager are available.

    Args:
        tts_checkpoint (str): path to the model checkpoint.
        tts_config_path (str): path to the model config file.
        use_cuda (bool): enable/disable CUDA use.

    Raises:
        ValueError: if the config enables phonemes but defines no phonemizer.
    """
    self.tts_config = load_config(tts_config_path)
    if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
        raise ValueError("Phonemizer is not defined in the TTS config.")

    self.tts_model = setup_tts_model(config=self.tts_config)

    # fall back to the encoder paths stored in the model config
    if not self.encoder_checkpoint:
        self._set_speaker_encoder_paths_from_tts_config()

    self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
    if use_cuda:
        self.tts_model.cuda()

    # The attribute may exist but be None; a bare hasattr() check would then
    # crash on `.init_encoder`, so check the value itself.
    speaker_manager = getattr(self.tts_model, "speaker_manager", None)
    if self.encoder_checkpoint and speaker_manager is not None:
        speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda)
|
||||
|
||||
def _set_speaker_encoder_paths_from_tts_config(self):
    """Copy the speaker encoder paths from the TTS config onto this instance.

    Nothing happens unless ``self.tts_config`` has ``model_args`` and those
    args have a ``speaker_encoder_config_path`` attribute. Otherwise
    ``self.encoder_checkpoint`` and ``self.encoder_config`` are set from
    ``speaker_encoder_model_path`` and ``speaker_encoder_config_path``.
    """
    config = self.tts_config
    if not hasattr(config, "model_args"):
        return

    model_args = config.model_args
    if not hasattr(model_args, "speaker_encoder_config_path"):
        return

    self.encoder_checkpoint = model_args.speaker_encoder_model_path
    self.encoder_config = model_args.speaker_encoder_config_path
|
||||
|
||||
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
    """Set up the vocoder model and its audio processor on this instance.

    The vocoder config is loaded first and used to create an
    ``AudioProcessor`` (from the config's ``audio`` section) and the vocoder
    model. The checkpoint is then loaded in eval mode and the model is moved
    to CUDA when requested.

    Args:
        model_file (str): path to the model checkpoint.
        model_config (str): path to the model config file.
        use_cuda (bool): enable/disable CUDA use.
    """
    config = load_config(model_config)
    self.vocoder_config = config
    self.vocoder_ap = AudioProcessor(verbose=False, **config.audio)

    vocoder = setup_vocoder_model(config)
    self.vocoder_model = vocoder
    vocoder.load_checkpoint(config, model_file, eval=True)

    if not use_cuda:
        return
    vocoder.cuda()
|
||||
|
||||
def split_into_sentences(self, text) -> List[str]:
    """Split the given text into sentences with the instance's segmenter.

    Args:
        text (str): input text in string format.

    Returns:
        List[str]: whatever ``self.seg.segment`` returns for ``text``
            (presumably one entry per sentence).
    """
    segmenter = self.seg
    sentences = segmenter.segment(text)
    return sentences
|
||||
|
||||
def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
    """Write a waveform to ``path`` at ``self.output_sample_rate``.

    Tensors are moved to the CPU and converted to numpy, and Python lists are
    converted to numpy arrays, before the module-level ``save_wav`` helper is
    called. Any other input is passed through unchanged.

    Args:
        wav (List[int]): waveform as a list of values (a tensor is accepted too).
        path (str): output path to save the waveform.
        pipe_out (BytesIO, optional): forwarded to the ``save_wav`` helper;
            per the original docs, used to stream the wav to stdout for shell pipes.
    """
    if torch.is_tensor(wav):
        audio = wav.cpu().numpy()
    elif isinstance(wav, list):
        audio = np.array(wav)
    else:
        audio = wav
    save_wav(wav=audio, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
|
||||
|
||||
def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
    """Run the loaded voice conversion model on a source/target pair.

    Args:
        source_wav (str): source audio, passed through to the model.
        target_wav (str): target audio, passed through to the model.

    Returns:
        List[int]: the result of ``self.vc_model.voice_conversion``.
    """
    return self.vc_model.voice_conversion(source_wav, target_wav)
|
||||
|
||||
def tts(
    self,
    text: str = "",
    speaker_name: str = "",
    language_name: str = "",
    speaker_wav=None,
    style_wav=None,
    style_text=None,
    reference_wav=None,
    reference_speaker_name=None,
    split_sentences: bool = True,
    **kwargs,
) -> List[int]:
    """🐸 TTS magic. Run all the models and generate speech.

    Without ``reference_wav`` the text is synthesized sentence by sentence and
    the sentence waveforms are concatenated, each followed by 10000 zero
    samples. With ``reference_wav`` the method performs voice conversion via
    ``transfer_voice`` instead.

    Args:
        text (str): input text.
        speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
        language_name (str, optional): language id for multi-language models. Defaults to "".
        speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None.
        style_wav (optional): style waveform for GST. Defaults to None.
        style_text (optional): transcription of style_wav for Capacitron. Defaults to None.
        reference_wav (optional): reference waveform for voice conversion. Defaults to None.
        reference_speaker_name (optional): speaker id of reference waveform. Defaults to None.
        split_sentences (bool, optional): split the input text into sentences. Defaults to True.
        **kwargs: additional arguments to pass to the TTS model. A ``voice_dir``
            entry is consumed here and stored on ``self.voice_dir``.

    Returns:
        List[int]: the generated waveform samples (a list in the synthesis
            path, the squeezed model/vocoder output in the voice conversion path).

    Raises:
        ValueError: if neither ``text`` nor ``reference_wav`` is given, or if
            the speaker/language selection is missing or invalid for the model.
    """
    start_time = time.time()
    wavs = []

    if not text and not reference_wav:
        raise ValueError(
            "You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API."
        )

    if text:
        sens = [text]
        if split_sentences:
            print(" > Text splitted to sentences.")
            sens = self.split_into_sentences(text)
        print(sens)

    # handle multi-speaker
    if "voice_dir" in kwargs:
        self.voice_dir = kwargs.pop("voice_dir")
    speaker_embedding = None
    speaker_id = None
    if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
        if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts":
            if self.tts_config.use_d_vector_file:
                # get the average speaker embedding from the saved d_vectors.
                speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
                    speaker_name, num_samples=None, randomize=False
                )
                speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
            else:
                # get speaker idx from the speaker name
                speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name]
        # handle Neon models with single speaker.
        elif len(self.tts_model.speaker_manager.name_to_id) == 1:
            speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0]
        elif not speaker_name and not speaker_wav:
            raise ValueError(
                " [!] Looks like you are using a multi-speaker model. "
                "You need to define either a `speaker_idx` or a `speaker_wav` to use a multi-speaker model."
            )
        else:
            speaker_embedding = None
    else:
        if speaker_name and self.voice_dir is None:
            raise ValueError(
                f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
                "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
            )

    # handle multi-lingual
    language_id = None
    if self.tts_languages_file or (
        hasattr(self.tts_model, "language_manager")
        and self.tts_model.language_manager is not None
        and not self.tts_config.model == "xtts"
    ):
        if len(self.tts_model.language_manager.name_to_id) == 1:
            language_id = list(self.tts_model.language_manager.name_to_id.values())[0]

        elif language_name and isinstance(language_name, str):
            try:
                language_id = self.tts_model.language_manager.name_to_id[language_name]
            except KeyError as e:
                raise ValueError(
                    f" [!] Looks like you use a multi-lingual model. "
                    f"Language {language_name} is not in the available languages: "
                    f"{self.tts_model.language_manager.name_to_id.keys()}."
                ) from e

        elif not language_name:
            raise ValueError(
                " [!] Look like you use a multi-lingual model. "
                "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model."
            )

        else:
            raise ValueError(
                f" [!] Missing language_ids.json file path for selecting language {language_name}."
                "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
            )

    # compute a new d_vector from the given clip.
    if (
        speaker_wav is not None
        and self.tts_model.speaker_manager is not None
        and hasattr(self.tts_model.speaker_manager, "encoder_ap")
        and self.tts_model.speaker_manager.encoder_ap is not None
    ):
        speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

    # without a vocoder the TTS model's own (Griffin-Lim) waveform is used
    vocoder_device = "cpu"
    use_gl = self.vocoder_model is None
    if not use_gl:
        vocoder_device = next(self.vocoder_model.parameters()).device
    if self.use_cuda:
        vocoder_device = "cuda"

    if not reference_wav:  # not voice conversion
        for sen in sens:
            if hasattr(self.tts_model, "synthesize"):
                outputs = self.tts_model.synthesize(
                    text=sen,
                    config=self.tts_config,
                    speaker_id=speaker_name,
                    voice_dirs=self.voice_dir,
                    d_vector=speaker_embedding,
                    speaker_wav=speaker_wav,
                    language=language_name,
                    **kwargs,
                )
            else:
                # synthesize voice
                outputs = synthesis(
                    model=self.tts_model,
                    text=sen,
                    CONFIG=self.tts_config,
                    use_cuda=self.use_cuda,
                    speaker_id=speaker_id,
                    style_wav=style_wav,
                    style_text=style_text,
                    use_griffin_lim=use_gl,
                    d_vector=speaker_embedding,
                    language_id=language_id,
                )
            waveform = outputs["wav"]
            if not use_gl:
                waveform = self._run_vocoder(outputs["outputs"]["model_outputs"][0], vocoder_device)
                if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
                    waveform = waveform.cpu()
                waveform = waveform.numpy()
            waveform = waveform.squeeze()

            # trim silence
            if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
                waveform = trim_silence(waveform, self.tts_model.ap)

            wavs += list(waveform)
            # pad each sentence with 10000 zero samples
            wavs += [0] * 10000
    else:
        # get the speaker embedding or speaker id for the reference wav file
        reference_speaker_embedding = None
        reference_speaker_id = None
        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
            if reference_speaker_name and isinstance(reference_speaker_name, str):
                if self.tts_config.use_d_vector_file:
                    # get the speaker embedding from the saved d_vectors.
                    reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name(
                        reference_speaker_name
                    )[0]
                    reference_speaker_embedding = np.array(reference_speaker_embedding)[
                        None, :
                    ]  # [1 x embedding_dim]
                else:
                    # get speaker idx from the speaker name
                    reference_speaker_id = self.tts_model.speaker_manager.name_to_id[reference_speaker_name]
            else:
                reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
                    reference_wav
                )
        outputs = transfer_voice(
            model=self.tts_model,
            CONFIG=self.tts_config,
            use_cuda=self.use_cuda,
            reference_wav=reference_wav,
            speaker_id=speaker_id,
            d_vector=speaker_embedding,
            use_griffin_lim=use_gl,
            reference_speaker_id=reference_speaker_id,
            reference_d_vector=reference_speaker_embedding,
        )
        waveform = outputs
        if not use_gl:
            waveform = self._run_vocoder(outputs[0], vocoder_device)
        if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
            waveform = waveform.cpu()
        if not use_gl:
            waveform = waveform.numpy()
        wavs = waveform.squeeze()

    # compute stats
    process_time = time.time() - start_time
    audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
    print(f" > Processing time: {process_time}")
    # no audio (e.g. the splitter produced no sentences) means no meaningful
    # real-time factor; skip it instead of dividing by zero
    if audio_time > 0:
        print(f" > Real-time factor: {process_time / audio_time}")
    return wavs

def _run_vocoder(self, model_output, vocoder_device):
    """Convert a TTS model output spectrogram to a waveform with the vocoder.

    Args:
        model_output: spectrogram tensor produced by the TTS model.
        vocoder_device: device the vocoder input is moved to before inference.

    Returns:
        The result of ``self.vocoder_model.inference``.
    """
    mel_postnet_spec = model_output.detach().cpu().numpy()
    # denormalize tts output based on tts audio config
    mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
    # renormalize spectrogram based on vocoder config
    vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
    # compute scale factor for possible sample rate mismatch
    scale_factor = [
        1,
        self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
    ]
    if scale_factor[1] != 1:
        print(" > interpolating tts model output.")
        vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
    else:
        vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
    # run vocoder model
    # [1, T, C]
    return self.vocoder_model.inference(vocoder_input.to(vocoder_device))
|
||||
@@ -0,0 +1,44 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    r"""Clip the gradients and report whether the gradient norm is infinite.

    The parameters to clip are chosen as follows: ``amp_opt_params`` when it
    is given (truthy); otherwise the model parameters, leaving out those whose
    name contains ``"stopnet"`` when ``ignore_stopnet`` is set.

    Args:
        model: module whose parameters are clipped when ``amp_opt_params`` is not given.
        grad_clip: max norm passed to ``torch.nn.utils.clip_grad_norm_``.
        ignore_stopnet (bool): exclude parameters with "stopnet" in their name.
        amp_opt_params: optional explicit parameters to clip instead.

    Returns:
        tuple: ``(grad_norm, skip_flag)``; ``skip_flag`` is True only when the
            norm is infinite. NaN norms are not flagged.
    """
    if amp_opt_params:
        params_to_clip = amp_opt_params
    elif ignore_stopnet:
        params_to_clip = [param for name, param in model.named_parameters() if "stopnet" not in name]
    else:
        params_to_clip = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, grad_clip)

    # compatibility with different torch versions: the norm may be a float or a tensor
    norm_is_inf = np.isinf(grad_norm) if isinstance(grad_norm, float) else torch.isinf(grad_norm)

    skip_flag = False
    if norm_is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    return grad_norm, skip_flag
|
||||
|
||||
|
||||
def gradual_training_scheduler(global_step, config):
    """Pick the active gradual-training entry for the current step.

    The step is scaled by the number of visible CUDA devices (at least 1) and
    compared against the first element of every entry in
    ``config.gradual_training``; the last entry whose threshold has been
    reached wins.

    Returns:
        tuple: elements 1 and 2 of the selected entry
            (presumably reduction factor and batch size; confirm against the config schema).

    Note:
        If no entry's threshold has been reached, the final indexing fails on ``None``.
    """
    gpu_count = max(torch.cuda.device_count(), 1)
    # we set the scheduling wrt num_gpus
    scaled_step = global_step * gpu_count

    selected = None
    for entry in config.gradual_training:
        if scaled_step >= entry[0]:
            selected = entry
    return selected[1], selected[2]
|
||||
@@ -0,0 +1,88 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
def read_audio(path):
    """Load an audio file with torchaudio and return it as a single channel.

    When the loaded tensor has more than one entry along its first dimension,
    those entries are averaged. The leading dimension is squeezed away before
    returning.

    Returns:
        tuple: ``(wav, sample_rate)``.
    """
    wav, sample_rate = torchaudio.load(path)

    has_multiple_channels = wav.size(0) > 1
    if has_multiple_channels:
        wav = wav.mean(dim=0, keepdim=True)

    mono = wav.squeeze(0)
    return mono, sample_rate
|
||||
|
||||
|
||||
def resample_wav(wav, sr, new_sr):
    """Resample ``wav`` from ``sr`` to ``new_sr`` with torchaudio.

    A leading dimension is added for the ``Resample`` transform and removed
    again from the result.
    """
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr)
    batched = wav.unsqueeze(0)
    return resampler(batched).squeeze(0)
|
||||
|
||||
|
||||
def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False):
    """Rescale ``start``/``end`` sample indices from ``vad_sr`` to ``new_sr``.

    Args:
        vad_sr: sample rate the timestamps were computed at.
        new_sr: sample rate to map the timestamps to.
        timestamps: sequence of dicts with ``"start"`` and ``"end"`` keys.
        just_begging_end (bool): when set and ``timestamps`` is non-empty,
            return a single span from the first start to the last end.

    Returns:
        list: dicts with integer (truncated) ``"start"`` and ``"end"`` values.
    """
    ratio = new_sr / vad_sr

    def _rescale(sample_index):
        return int(sample_index * ratio)

    if just_begging_end and timestamps:
        # get just the start and end timestamps
        return [{"start": _rescale(timestamps[0]["start"]), "end": _rescale(timestamps[-1]["end"])}]

    # map every span to the new SR
    return [{"start": _rescale(ts["start"]), "end": _rescale(ts["end"])} for ts in timestamps]
|
||||
|
||||
|
||||
def get_vad_model_and_utils(use_cuda=False, use_onnx=False):
    """Fetch the silero VAD model and its helper functions through ``torch.hub``.

    The hub entry is always force-reloaded. The model is moved to CUDA when
    ``use_cuda`` is set.

    Returns:
        tuple: ``(model, get_speech_timestamps, save_audio, collect_chunks)``;
            the remaining two hub utilities are discarded.
    """
    vad_model, hub_utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=use_onnx, force_onnx_cpu=True
    )
    if use_cuda:
        vad_model = vad_model.cuda()

    # the hub returns five helpers; only three are exposed to callers
    speech_timestamps_fn, save_audio_fn, _, _, collect_chunks_fn = hub_utils
    return vad_model, speech_timestamps_fn, save_audio_fn, collect_chunks_fn
|
||||
|
||||
|
||||
def remove_silence(
    model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False
):
    """Run VAD on an audio file and save the speech-only audio to ``out_path``.

    Args:
        model_and_utils: 4-tuple ``(model, get_speech_timestamps, _, collect_chunks)``,
            as returned by ``get_vad_model_and_utils``.
        audio_path: path of the audio file to process.
        out_path: path the resulting audio is written to.
        vad_sample_rate (int): sample rate the VAD model is run at.
        trim_just_beginning_and_end (bool): keep everything between the first
            and the last detected speech span instead of cutting inner silences.
        use_cuda (bool): move the VAD input to CUDA.

    Returns:
        tuple: ``(out_path, is_speech)``. When the file cannot be read, nothing
            is written and ``(None, False)`` is returned. When no speech is
            detected, the unmodified audio is still written and ``is_speech``
            is False.
    """
    # get the VAD model and utils functions
    model, get_speech_timestamps, _, collect_chunks = model_and_utils

    # read ground truth wav and resample the audio for the VAD
    try:
        wav, gt_sample_rate = read_audio(audio_path)
    except Exception:  # best-effort read, but let KeyboardInterrupt/SystemExit propagate
        print(f"> ❗ Failed to read {audio_path}")
        return None, False

    # if needed, resample the audio for the VAD model
    if gt_sample_rate != vad_sample_rate:
        wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate)
    else:
        wav_vad = wav

    if use_cuda:
        wav_vad = wav_vad.cuda()

    # get speech timestamps from full audio file
    speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768)

    # map the current speech_timestamps to the sample rate of the ground truth audio
    new_speech_timestamps = map_timestamps_to_new_sr(
        vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end
    )

    # if have speech timestamps else save the wav
    if new_speech_timestamps:
        wav = collect_chunks(new_speech_timestamps, wav)
        is_speech = True
    else:
        print(f"> The file {audio_path} probably does not have speech please check it !!")
        is_speech = False

    # save
    torchaudio.save(out_path, wav[None, :], gt_sample_rate)
    return out_path, is_speech
|
||||
Reference in New Issue
Block a user