diff --git a/TTS/tts/__init__.py b/TTS/tts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/TTS/tts/configs/__init__.py b/TTS/tts/configs/__init__.py
new file mode 100644
index 0000000..3146ac1
--- /dev/null
+++ b/TTS/tts/configs/__init__.py
@@ -0,0 +1,17 @@
+import importlib
+import os
+from inspect import isclass
+
+# import all files under configs/
+# configs_dir = os.path.dirname(__file__)
+# for file in os.listdir(configs_dir):
+#     path = os.path.join(configs_dir, file)
+#     if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
+#         config_name = file[: file.find(".py")] if file.endswith(".py") else file
+#         module = importlib.import_module("TTS.tts.configs." + config_name)
+#         for attribute_name in dir(module):
+#             attribute = getattr(module, attribute_name)
+
+#             if isclass(attribute):
+#                 # Add the class to this package's variables
+#                 globals()[attribute_name] = attribute
diff --git a/TTS/tts/configs/align_tts_config.py b/TTS/tts/configs/align_tts_config.py
new file mode 100644
index 0000000..317a01a
--- /dev/null
+++ b/TTS/tts/configs/align_tts_config.py
@@ -0,0 +1,107 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.align_tts import AlignTTSArgs
+
+
+@dataclass
+class AlignTTSConfig(BaseTTSConfig):
+    """Defines parameters for the AlignTTS model.
+    Example:
+
+        >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig
+        >>> config = AlignTTSConfig()
+
+    Args:
+        model (str):
+            Model name used for selecting the right model at initialization. Defaults to `align_tts`.
+        positional_encoding (bool):
+            enable / disable positional encoding applied to the encoder output. Defaults to True.
+        hidden_channels (int):
+            Base number of hidden channels. Defines all the layers except the ones defined by the specific encoder or decoder
+            parameters. Defaults to 256.
+        hidden_channels_dp (int):
+            Number of hidden channels of the duration predictor's layers. Defaults to 256.
+        encoder_type (str):
+            Type of the encoder used by the model. Look at `TTS.tts.layers.feed_forward.encoder` for more details.
+            Defaults to `fftransformer`.
+        encoder_params (dict):
+            Parameters used to define the encoder network. Look at `TTS.tts.layers.feed_forward.encoder` for more details.
+            Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`.
+        decoder_type (str):
+            Type of the decoder used by the model. Look at `TTS.tts.layers.feed_forward.decoder` for more details.
+            Defaults to `fftransformer`.
+        decoder_params (dict):
+            Parameters used to define the decoder network. Look at `TTS.tts.layers.feed_forward.decoder` for more details.
+            Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`.
+        phase_start_steps (List[int]):
+            A list of the number of steps required to start the next training phase. AlignTTS has 4 different training
+            phases, so you need to define 4 different values to enable phase-based training. If None, the whole model
+            is trained together. Defaults to None.
+        ssim_alpha (float):
+            Weight for the SSIM loss. If set <= 0, disables the SSIM loss. Defaults to 1.0.
+        dur_loss_alpha (float):
+            Weight for the duration predictor's loss. Defaults to 1.0.
+        mdn_alpha (float):
+            Weight for the MDN loss. Defaults to 1.0.
+        spec_loss_alpha (float):
+            Weight for the MSE spectrogram loss. If set <= 0, disables the spectrogram loss. Defaults to 1.0.
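A minimal sketch of enabling phase-based training with this config; the four step values are illustrative, not defaults from this diff:

    >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig
    >>> config = AlignTTSConfig(phase_start_steps=[0, 40000, 80000, 160000])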
+ use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage.""" + + model: str = "align_tts" + # model specific params + model_args: AlignTTSArgs = field(default_factory=AlignTTSArgs) + phase_start_steps: List[int] = None + + ssim_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + mdn_alpha: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = None + lr_scheduler_params: dict = None + lr: float = 1e-4 + grad_clip: float = 5.0 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py new file mode 100644 index 0000000..4d1cd13 --- /dev/null +++ b/TTS/tts/configs/bark_config.py @@ -0,0 +1,105 @@ +import os +from dataclasses import dataclass, field +from typing import Dict + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.layers.bark.model import GPTConfig +from TTS.tts.layers.bark.model_fine import FineGPTConfig +from TTS.tts.models.bark import BarkAudioConfig +from TTS.utils.generic_utils import get_user_data_dir + + +@dataclass +class BarkConfig(BaseTTSConfig): + """Bark TTS configuration + + Args: + model (str): model name that registers the model. + audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig(). + num_chars (int): number of characters in the alphabet. Defaults to 0. + semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig(). + fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig(). + coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig(). + CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024. + SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9. + SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000. + CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024. + N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2. + N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8. + COARSE_RATE_HZ (int): coarse tokens rate in Hz. 
Defaults to 75. + SAMPLE_RATE (int): sample rate. Defaults to 24_000. + USE_SMALLER_MODELS (bool): use smaller models. Defaults to False. + TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048. + SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000. + TEXT_PAD_TOKEN ([type]): text pad token. Defaults to 10_048. + TEXT_EOS_TOKEN ([type]): text end of sentence token. Defaults to 10_049. + TEXT_SOS_TOKEN ([type]): text start of sentence token. Defaults to 10_050. + SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 10_051. + COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048. + COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050. + REMOTE_BASE_URL ([type]): remote base url. Defaults to "https://huggingface.co/erogol/bark/tree". + REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None. + LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None. + SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None. + CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir(). + DEF_SPEAKER_DIR (str): default speaker directory to stoke speaker values for voice cloning. Defaults to get_user_data_dir(). + """ + + model: str = "bark" + audio: BarkAudioConfig = field(default_factory=BarkAudioConfig) + num_chars: int = 0 + semantic_config: GPTConfig = field(default_factory=GPTConfig) + fine_config: FineGPTConfig = field(default_factory=FineGPTConfig) + coarse_config: GPTConfig = field(default_factory=GPTConfig) + CONTEXT_WINDOW_SIZE: int = 1024 + SEMANTIC_RATE_HZ: float = 49.9 + SEMANTIC_VOCAB_SIZE: int = 10_000 + CODEBOOK_SIZE: int = 1024 + N_COARSE_CODEBOOKS: int = 2 + N_FINE_CODEBOOKS: int = 8 + COARSE_RATE_HZ: int = 75 + SAMPLE_RATE: int = 24_000 + USE_SMALLER_MODELS: bool = False + + TEXT_ENCODING_OFFSET: int = 10_048 + SEMANTIC_PAD_TOKEN: int = 10_000 + TEXT_PAD_TOKEN: int = 129_595 + SEMANTIC_INFER_TOKEN: int = 129_599 + COARSE_SEMANTIC_PAD_TOKEN: int = 12_048 + COARSE_INFER_TOKEN: int = 12_050 + + REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/" + REMOTE_MODEL_PATHS: Dict = None + LOCAL_MODEL_PATHS: Dict = None + SMALL_REMOTE_MODEL_PATHS: Dict = None + CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0")) + DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers")) + + def __post_init__(self): + self.REMOTE_MODEL_PATHS = { + "text": { + "path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"), + "checksum": "54afa89d65e318d4f5f80e8e8799026a", + }, + "coarse": { + "path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"), + "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", + }, + "fine": { + "path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"), + "checksum": "59d184ed44e3650774a2f0503a48a97b", + }, + } + self.LOCAL_MODEL_PATHS = { + "text": os.path.join(self.CACHE_DIR, "text_2.pt"), + "coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"), + "fine": os.path.join(self.CACHE_DIR, "fine_2.pt"), + "hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"), + "hubert": os.path.join(self.CACHE_DIR, "hubert.pt"), + } + self.SMALL_REMOTE_MODEL_PATHS = { + "text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")}, + "coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")}, + "fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")}, + } + self.sample_rate = self.SAMPLE_RATE # pylint: disable=attribute-defined-outside-init diff --git a/TTS/tts/configs/delightful_tts_config.py b/TTS/tts/configs/delightful_tts_config.py new file 
mode 100644 index 0000000..805d995 --- /dev/null +++ b/TTS/tts/configs/delightful_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig + + +@dataclass +class DelightfulTTSConfig(BaseTTSConfig): + """ + Configuration class for the DelightfulTTS model. + + Attributes: + model (str): Name of the model ("delightful_tts"). + audio (DelightfulTtsAudioConfig): Configuration for audio settings. + model_args (DelightfulTtsArgs): Configuration for model arguments. + use_attn_priors (bool): Whether to use attention priors. + vocoder (VocoderConfig): Configuration for the vocoder. + init_discriminator (bool): Whether to initialize the discriminator. + steps_to_start_discriminator (int): Number of steps to start the discriminator. + grad_clip (List[float]): Gradient clipping values. + lr_gen (float): Learning rate for the gan generator. + lr_disc (float): Learning rate for the gan discriminator. + lr_scheduler_gen (str): Name of the learning rate scheduler for the generator. + lr_scheduler_gen_params (dict): Parameters for the learning rate scheduler for the generator. + lr_scheduler_disc (str): Name of the learning rate scheduler for the discriminator. + lr_scheduler_disc_params (dict): Parameters for the learning rate scheduler for the discriminator. + scheduler_after_epoch (bool): Whether to schedule after each epoch. + optimizer (str): Name of the optimizer. + optimizer_params (dict): Parameters for the optimizer. + ssim_loss_alpha (float): Alpha value for the SSIM loss. + mel_loss_alpha (float): Alpha value for the mel loss. + aligner_loss_alpha (float): Alpha value for the aligner loss. + pitch_loss_alpha (float): Alpha value for the pitch loss. + energy_loss_alpha (float): Alpha value for the energy loss. + u_prosody_loss_alpha (float): Alpha value for the utterance prosody loss. + p_prosody_loss_alpha (float): Alpha value for the phoneme prosody loss. + dur_loss_alpha (float): Alpha value for the duration loss. + char_dur_loss_alpha (float): Alpha value for the character duration loss. + binary_align_loss_alpha (float): Alpha value for the binary alignment loss. + binary_loss_warmup_epochs (int): Number of warm-up epochs for the binary loss. + disc_loss_alpha (float): Alpha value for the discriminator loss. + gen_loss_alpha (float): Alpha value for the generator loss. + feat_loss_alpha (float): Alpha value for the feature loss. + vocoder_mel_loss_alpha (float): Alpha value for the vocoder mel loss. + multi_scale_stft_loss_alpha (float): Alpha value for the multi-scale STFT loss. + multi_scale_stft_loss_params (dict): Parameters for the multi-scale STFT loss. + return_wav (bool): Whether to return audio waveforms. + use_weighted_sampler (bool): Whether to use a weighted sampler. + weighted_sampler_attrs (dict): Attributes for the weighted sampler. + weighted_sampler_multipliers (dict): Multipliers for the weighted sampler. + r (int): Value for the `r` override. + compute_f0 (bool): Whether to compute F0 values. + f0_cache_path (str): Path to the F0 cache. + attn_prior_cache_path (str): Path to the attention prior cache. + num_speakers (int): Number of speakers. + use_speaker_embedding (bool): Whether to use speaker embedding. + speakers_file (str): Path to the speaker file. + speaker_embedding_channels (int): Number of channels for the speaker embedding. 
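As a rough usage sketch, the multi-speaker fields below are copied into `model_args` by `__post_init__`; the d-vector file path is hypothetical, and this assumes `DelightfulTtsArgs` exposes the matching fields:

    >>> from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
    >>> config = DelightfulTTSConfig(use_d_vector_file=True, d_vector_file="/data/d_vectors.json", d_vector_dim=256)
    >>> config.model_args.use_d_vector_file
    True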
+ language_ids_file (str): Path to the language IDs file. + """ + + model: str = "delightful_tts" + + # model specific params + audio: DelightfulTtsAudioConfig = field(default_factory=DelightfulTtsAudioConfig) + model_args: DelightfulTtsArgs = field(default_factory=DelightfulTtsArgs) + use_attn_priors: bool = True + + # vocoder + vocoder: VocoderConfig = field(default_factory=VocoderConfig) + init_discriminator: bool = True + + # optimizer + steps_to_start_discriminator: int = 200000 + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # acoustic model loss params + ssim_loss_alpha: float = 1.0 + mel_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + energy_loss_alpha: float = 1.0 + u_prosody_loss_alpha: float = 0.5 + p_prosody_loss_alpha: float = 0.5 + dur_loss_alpha: float = 1.0 + char_dur_loss_alpha: float = 0.01 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 10 + + # vocoder loss params + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + vocoder_mel_loss_alpha: float = 10.0 + multi_scale_stft_loss_alpha: float = 2.5 + multi_scale_stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # data loader params + return_wav: bool = True + use_weighted_sampler: bool = False + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + attn_prior_cache_path: str = None + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: str = None + d_vector_dim: int = None + + # testing + test_sentences: List[List[str]] = field( + default_factory=lambda: [ + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. It's so delicious and moist."], + ["Prior to November 22, 1963."], + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py new file mode 100644 index 0000000..d086d26 --- /dev/null +++ b/TTS/tts/configs/fast_pitch_config.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastPitchConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fast_pitch_config import FastPitchConfig + >>> config = FastPitchConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. 
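A minimal sketch of a FastPitch config with pitch computation enabled; the cache path is illustrative:

    >>> from TTS.tts.configs.fast_pitch_config import FastPitchConfig
    >>> config = FastPitchConfig(compute_f0=True, f0_cache_path="/tmp/f0_cache", max_seq_len=500)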
+ + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None + """ + + model: str = "fast_pitch" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py new file mode 100644 index 0000000..af6c2db --- /dev/null +++ b/TTS/tts/configs/fast_speech_config.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastSpeechConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastSpeech model. + + Example: + + >>> from TTS.tts.configs.fast_speech_config import FastSpeechConfig + >>> config = FastSpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastSpeechArgs` for more details. Defaults to `FastSpeechArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. 
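As the defaults below show, FastSpeech reuses `ForwardTTS` with the pitch predictor disabled; a quick check:

    >>> from TTS.tts.configs.fast_speech_config import FastSpeechConfig
    >>> config = FastSpeechConfig()
    >>> config.model_args.use_pitch
    False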
+ + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "fast_speech" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=False)) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/fastspeech2_config.py b/TTS/tts/configs/fastspeech2_config.py new file mode 100644 index 0000000..d179617 --- /dev/null +++ b/TTS/tts/configs/fastspeech2_config.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class Fastspeech2Config(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fastspeech2_config import FastSpeech2Config + >>> config = FastSpeech2Config() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. 
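A minimal sketch of a FastSpeech2 config with both pitch and energy enabled (the cache paths are illustrative); note that the class defined in this file is named `Fastspeech2Config`:

    >>> from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
    >>> config = Fastspeech2Config(compute_f0=True, compute_energy=True, f0_cache_path="/tmp/f0_cache", energy_cache_path="/tmp/energy_cache")
    >>> config.model_args.use_energy
    True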
+ + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + energy_loss_alpha (float): + Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None + + # dataset configs + compute_energy(bool): + Compute energy. defaults to True + + energy_cache_path(str): + energy cache path. defaults to None + """ + + model: str = "fastspeech2" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=True, use_energy=True)) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + energy_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # dataset configs + compute_energy: bool = True + energy_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py new file mode 100644 index 0000000..f42f3e5 --- /dev/null +++ b/TTS/tts/configs/glow_tts_config.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class GlowTTSConfig(BaseTTSConfig): + """Defines parameters for GlowTTS model. + + Example: + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> config = GlowTTSConfig() + + Args: + model(str): + Model name used for selecting the right model at initialization. Defaults to `glow_tts`. + encoder_type (str): + Type of the encoder used by the model. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `rel_pos_transformers`. + encoder_params (dict): + Parameters used to define the encoder network. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `{"kernel_size": 3, "dropout_p": 0.1, "num_layers": 6, "num_heads": 2, "hidden_channels_ffn": 768}` + use_encoder_prenet (bool): + enable / disable the use of a prenet for the encoder. Defaults to True. + hidden_channels_enc (int): + Number of base hidden channels used by the encoder network. It defines the input and the output channel sizes, + and for some encoder types internal hidden channels sizes too. Defaults to 192. + hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 192 as in the original work. + hidden_channels_dp (int): + Number of layer channels of the duration predictor network. Defaults to 256 as in the original work. + mean_only (bool): + If true predict only the mean values by the decoder flow. Defaults to True. + out_channels (int): + Number of channels of the model output tensor. Defaults to 80. + num_flow_blocks_dec (int): + Number of decoder blocks. Defaults to 12. + inference_noise_scale (float): + Noise scale used at inference. Defaults to 0.33. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_block_layers (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate for decoder. Defaults to 0.1. + num_speaker (int): + Number of speaker to define the size of speaker embedding layer. Defaults to 0. + c_in_channels (int): + Number of speaker embedding channels. It is set to 512 if embeddings are learned. Defaults to 0. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. + mean_only (bool): + If True, encoder only computes mean value and uses constant variance for each time step. 
Defaults to true. + encoder_type (str): + Encoder module type. Possible values are`["rel_pos_transformer", "gated_conv", "residual_conv_bn", "time_depth_separable"]` + Check `TTS.tts.layers.glow_tts.encoder` for more details. Defaults to `rel_pos_transformers` as in the original paper. + encoder_params (dict): + Encoder module parameters. Defaults to None. + d_vector_dim (int): + Channels of external speaker embedding vectors. Defaults to 0. + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + style_wav_for_test (str): + Path to the wav file used for changing the style of the speech. Defaults to None. + inference_noise_scale (float): + Variance used for sampling the random noise added to the decoder's input at inference. Defaults to 0.0. + length_scale (float): + Multiply the predicted durations with this value to change the speech speed. Defaults to 1. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. 
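A short sketch of the inference-time controls described above; the values are illustrative:

    >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
    >>> config = GlowTTSConfig(inference_noise_scale=0.33, length_scale=1.2)  # length_scale > 1 slows the speech down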
+ """ + + model: str = "glow_tts" + + # model params + num_chars: int = None + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + } + ) + use_encoder_prenet: bool = True + hidden_channels_enc: int = 192 + hidden_channels_dec: int = 192 + hidden_channels_dp: int = 256 + dropout_p_dp: float = 0.1 + dropout_p_dec: float = 0.05 + mean_only: bool = True + out_channels: int = 80 + num_flow_blocks_dec: int = 12 + inference_noise_scale: float = 0.33 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_block_layers: int = 4 + num_speakers: int = 0 + c_in_channels: int = 0 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + "input_length": None, + } + ) + d_vector_dim: int = 0 + + # training params + data_dep_init_steps: int = 10 + + # inference params + style_wav_for_test: str = None + inference_noise_scale: float = 0.0 + length_scale: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + grad_clip: float = 5.0 + lr: float = 1e-3 + + # overrides + min_seq_len: int = 3 + max_seq_len: int = 500 + r: int = 1 # DO NOT CHANGE - TODO: make this immutable once coqpit implements it. + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/neuralhmm_tts_config.py b/TTS/tts/configs/neuralhmm_tts_config.py new file mode 100644 index 0000000..50f7284 --- /dev/null +++ b/TTS/tts/configs/neuralhmm_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class NeuralhmmTTSConfig(BaseTTSConfig): + """ + Define parameters for Neural HMM TTS model. + + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. Defaults to 500. + plot_step (int): + Plot training stats on the logger every plot_step steps. Defaults to 1. + model_param_stats (bool): + Log model parameters stats on the logger dashboard. Defaults to False. + force_generate_statistics (bool): + Force generate mel normalization statistics. Defaults to False. 
+        mel_statistics_parameter_path (str):
+            Path to the mel normalization statistics. If the model does not find a file there, it will generate the statistics.
+            Defaults to None.
+        num_chars (int):
+            Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
+        state_per_phone (int):
+            Generates N states per phone. Similar to the `add_blank` parameter in GlowTTS, but here it is upsampled by the model's encoder. Defaults to 2.
+        encoder_in_out_features (int):
+            Channels of encoder input and character embedding tensors. Defaults to 512.
+        encoder_n_convolutions (int):
+            Number of convolution layers in the encoder. Defaults to 3.
+        out_channels (int):
+            Channels of the final model output. It must match the spectrogram size. Defaults to 80.
+        ar_order (int):
+            Autoregressive order of the model. Defaults to 1. In the Neural HMM ablations, more autoregression adds variation but hurts the naturalness of the synthesised audio.
+        sampling_temp (float):
+            Variation added to the sample from the latent space of the neural HMM. Defaults to 0.
+        deterministic_transition (bool):
+            Deterministic duration generation based on duration quantiles as defined in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Median-based generation of synthetic speech durations using a non-parametric approach,” in Proc. SLT, 2016.". Defaults to True.
+        duration_threshold (float):
+            Threshold for duration quantiles. Defaults to 0.43. Tune this to change the speaking rate of the synthesis, where lower values give a slower speaking rate and higher values a faster one.
+        use_grad_checkpointing (bool):
+            Use gradient checkpointing to save memory. PyTorch currently does not support gradient checkpointing inside a loop, so it has to be turned off in a multi-GPU setting. Adjust it depending on whether you get a larger batch size with a single GPU or with multiple GPUs. Defaults to True.
+        max_sampling_time (int):
+            Maximum sampling time while synthesising latents from the neural HMM. Defaults to 1000.
+        prenet_type (str):
+            `original` or `bn`. `original` sets the default Prenet and `bn` uses the Batch Normalization version of the
+            Prenet. Defaults to `original`.
+        prenet_dim (int):
+            Dimension of the Prenet. Defaults to 256.
+        prenet_n_layers (int):
+            Number of layers in the Prenet. Defaults to 2.
+        prenet_dropout (float):
+            Dropout rate of the Prenet. Defaults to 0.5.
+        prenet_dropout_at_inference (bool):
+            Use dropout at inference time. Defaults to True.
+        memory_rnn_dim (int):
+            Dimension of the memory LSTM that processes the prenet output. Defaults to 1024.
+        outputnet_size (list[int]):
+            Size of the output network inside the neural HMM. Defaults to [1024].
+        flat_start_params (dict):
+            Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`.
+            It will be recomputed when you pass the dataset.
+        std_floor (float):
+            Floor value for the standard deviation of the neural HMM. Prevents the model from cheating by putting a point mass and getting infinite likelihood at any datapoint. Defaults to 0.001.
+            It is called `variance flooring` in standard HMM literature.
+        optimizer (str):
+            Optimizer to use for training. Defaults to `Adam`.
+        optimizer_params (dict):
+            Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`.
+        grad_clip (float):
+            Gradient clipping threshold. Defaults to 40_000.
+        lr (float):
+            Learning rate. Defaults to 1e-3.
+        lr_scheduler (str):
+            Learning rate scheduler for the training.
Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "NeuralHMM_TTS" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0 + deterministic_transition: bool = True + duration_threshold: float = 0.43 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = True + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.001 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. + + Raises: + AssertionError: when the parameters network is not defined + AssertionError: transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/TTS/tts/configs/overflow_config.py b/TTS/tts/configs/overflow_config.py new file mode 100644 index 0000000..dc3e554 --- /dev/null +++ b/TTS/tts/configs/overflow_config.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class OverflowConfig(BaseTTSConfig): # The classname has to be camel case + """ + Define parameters for OverFlow model. + + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. Defaults to 500. 
+        plot_step (int):
+            Plot training stats on the logger every plot_step steps. Defaults to 1.
+        model_param_stats (bool):
+            Log model parameters stats on the logger dashboard. Defaults to False.
+        force_generate_statistics (bool):
+            Force generate mel normalization statistics. Defaults to False.
+        mel_statistics_parameter_path (str):
+            Path to the mel normalization statistics. If the model does not find a file there, it will generate the statistics.
+            Defaults to None.
+        num_chars (int):
+            Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
+        state_per_phone (int):
+            Generates N states per phone. Similar to the `add_blank` parameter in GlowTTS, but in Overflow it is upsampled by the model's encoder. Defaults to 2.
+        encoder_in_out_features (int):
+            Channels of encoder input and character embedding tensors. Defaults to 512.
+        encoder_n_convolutions (int):
+            Number of convolution layers in the encoder. Defaults to 3.
+        out_channels (int):
+            Channels of the final model output. It must match the spectrogram size. Defaults to 80.
+        ar_order (int):
+            Autoregressive order of the model. Defaults to 1. In the Neural HMM ablations, more autoregression adds variation but hurts the naturalness of the synthesised audio.
+        sampling_temp (float):
+            Variation added to the sample from the latent space of the neural HMM. Defaults to 0.334.
+        deterministic_transition (bool):
+            Deterministic duration generation based on duration quantiles as defined in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Median-based generation of synthetic speech durations using a non-parametric approach,” in Proc. SLT, 2016.". Defaults to True.
+        duration_threshold (float):
+            Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values give a slower speaking rate and higher values a faster one.
+        use_grad_checkpointing (bool):
+            Use gradient checkpointing to save memory. PyTorch currently does not support gradient checkpointing inside a loop, so it has to be turned off in a multi-GPU setting. Adjust it depending on whether you get a larger batch size with a single GPU or with multiple GPUs. Defaults to True.
+        max_sampling_time (int):
+            Maximum sampling time while synthesising latents from the neural HMM. Defaults to 1000.
+        prenet_type (str):
+            `original` or `bn`. `original` sets the default Prenet and `bn` uses the Batch Normalization version of the
+            Prenet. Defaults to `original`.
+        prenet_dim (int):
+            Dimension of the Prenet. Defaults to 256.
+        prenet_n_layers (int):
+            Number of layers in the Prenet. Defaults to 2.
+        prenet_dropout (float):
+            Dropout rate of the Prenet. Defaults to 0.5.
+        prenet_dropout_at_inference (bool):
+            Use dropout at inference time. Defaults to False.
+        memory_rnn_dim (int):
+            Dimension of the memory LSTM that processes the prenet output. Defaults to 1024.
+        outputnet_size (list[int]):
+            Size of the output network inside the neural HMM. Defaults to [1024].
+        flat_start_params (dict):
+            Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`.
+            It will be recomputed when you pass the dataset.
+        std_floor (float):
+            Floor value for the standard deviation of the neural HMM. Prevents the model from cheating by putting a point mass and getting infinite likelihood at any datapoint. Defaults to 0.01.
+            It is called `variance flooring` in standard HMM literature.
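A short sketch of the sampling controls described above, using the documented defaults; lower `duration_threshold` gives slower speech and higher `sampling_temp` gives more variation:

    >>> from TTS.tts.configs.overflow_config import OverflowConfig
    >>> config = OverflowConfig(sampling_temp=0.334, duration_threshold=0.55)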
+ hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 150. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_flow_blocks_dec (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate of the decoder. Defaults to 0.05. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. + c_in_channels (int): + Unused parameter from GlowTTS's decoder. Defaults to 0. + optimizer (str): + Optimizer to use for training. Defaults to `adam`. + optimizer_params (dict): + Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`. + grad_clip (float): + Gradient clipping threshold. Defaults to 40_000. + lr (float): + Learning rate. Defaults to 1e-3. + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "Overflow" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0.334 + deterministic_transition: bool = True + duration_threshold: float = 0.55 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = False + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.01 + + # Decoder parameters + hidden_channels_dec: int = 150 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_flow_blocks_dec: int = 12 + num_block_layers: int = 4 + dropout_p_dec: float = 0.05 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + c_in_channels: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. 
+ + Raises: + AssertionError: when the parameter network is not defined + AssertionError: when the transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0; it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"The parameter network must have at least one layer; check `outputnet_size` in the config. Provided: {self.outputnet_size}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py new file mode 100644 index 0000000..bf17322 --- /dev/null +++ b/TTS/tts/configs/shared_configs.py @@ -0,0 +1,344 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import Coqpit, check_argument + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class GSTConfig(Coqpit): + """Defines the Global Style Token Module + + Args: + gst_style_input_wav (str): + Path to the wav file used to define the style of the output speech at inference. Defaults to None. + + gst_style_input_weights (dict): + Defines the weights for each style token used at inference. Defaults to None. + + gst_embedding_dim (int): + Defines the size of the GST embedding vector dimensions. Defaults to 256. + + gst_num_heads (int): + Number of attention heads used by the multi-head attention. Defaults to 4. + + gst_num_style_tokens (int): + Number of style token vectors. Defaults to 10. + """ + + gst_style_input_wav: str = None + gst_style_input_weights: dict = None + gst_embedding_dim: int = 256 + gst_use_speaker_embedding: bool = False + gst_num_heads: int = 4 + gst_num_style_tokens: int = 10 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("gst_style_input_weights", c, restricted=False) + check_argument("gst_style_input_wav", c, restricted=False) + check_argument("gst_embedding_dim", c, restricted=True, min_val=0, max_val=1000) + check_argument("gst_use_speaker_embedding", c, restricted=False) + check_argument("gst_num_heads", c, restricted=True, min_val=2, max_val=10) + check_argument("gst_num_style_tokens", c, restricted=True, min_val=1, max_val=1000) + + +@dataclass +class CapacitronVAEConfig(Coqpit): + """Defines the capacitron VAE Module + Args: + capacitron_capacity (int): + Defines the variational capacity limit of the prosody embeddings. Defaults to 150. + capacitron_VAE_embedding_dim (int): + Defines the size of the Capacitron embedding vector dimension. Defaults to 128. + capacitron_use_text_summary_embeddings (bool): + If True, use a text summary embedding in Capacitron. Defaults to True. + capacitron_text_summary_embedding_dim (int): + Defines the size of the capacitron text embedding vector dimension. Defaults to 128. + capacitron_use_speaker_embedding (bool): + if True use speaker embeddings in Capacitron. Defaults to False. + capacitron_VAE_loss_alpha (float): + Weight for the VAE loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + capacitron_grad_clip (float): + Gradient clipping value for all gradients except beta.
Defaults to 5.0 + """ + + capacitron_loss_alpha: int = 1 + capacitron_capacity: int = 150 + capacitron_VAE_embedding_dim: int = 128 + capacitron_use_text_summary_embeddings: bool = True + capacitron_text_summary_embedding_dim: int = 128 + capacitron_use_speaker_embedding: bool = False + capacitron_VAE_loss_alpha: float = 0.25 + capacitron_grad_clip: float = 5.0 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("capacitron_capacity", c, restricted=True, min_val=10, max_val=500) + check_argument("capacitron_VAE_embedding_dim", c, restricted=True, min_val=16, max_val=1024) + check_argument("capacitron_use_speaker_embedding", c, restricted=False) + check_argument("capacitron_text_summary_embedding_dim", c, restricted=False, min_val=16, max_val=512) + check_argument("capacitron_VAE_loss_alpha", c, restricted=False) + check_argument("capacitron_grad_clip", c, restricted=False) + + +@dataclass +class CharactersConfig(Coqpit): + """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses. + + Args: + characters_class (str): + Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on + the configuration. Defaults to None. + + vocab_dict (dict): + Defines the vocabulary dictionary used to encode the characters. Defaults to None. + + pad (str): + characters in place of empty padding. Defaults to None. + + eos (str): + characters showing the end of a sentence. Defaults to None. + + bos (str): + characters showing the beginning of a sentence. Defaults to None. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + + characters (str): + character set used by the model. Characters not in this list are ignored when converting input text to + a list of sequence IDs. Defaults to None. + + punctuations (str): + characters considered as punctuation as parsing the input sentence. Defaults to None. + + phonemes (str): + characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new + models. Defaults to None. + + is_unique (bool): + remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old + models trained with character lists with duplicates. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + characters_class: str = None + + # using BaseVocabulary + vocab_dict: Dict = None + + # using on BaseCharacters + pad: str = None + eos: str = None + bos: str = None + blank: str = None + characters: str = None + punctuations: str = None + phonemes: str = None + is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates + is_sorted: bool = True + + +@dataclass +class BaseTTSConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + + audio (BaseAudioConfig): + Audio processor config object instance. + + use_phonemes (bool): + enable / disable phoneme use. + + phonemizer (str): + Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`. + Defaults to None. + + phoneme_language (str): + Language code for the phonemizer. You can check the list of supported languages by running + `python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None. + + compute_input_seq_cache (bool): + enable / disable precomputation of the phoneme sequences. 
At the expense of some delay at the beginning of + the training, It allows faster data loader time and precise limitation with `max_seq_len` and + `min_seq_len`. + + text_cleaner (str): + Name of the text cleaner used for cleaning and formatting transcripts. + + enable_eos_bos_chars (bool): + enable / disable the use of eos and bos characters. + + test_senteces_file (str): + Path to a txt file that has sentences used at test time. The file must have a sentence per line. + + phoneme_cache_path (str): + Path to the output folder caching the computed phonemes for each sample. + + characters (CharactersConfig): + Instance of a CharactersConfig class. + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. + + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. + + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). + + compute_f0 (int): + (Not in use yet). + + compute_energy (int): + (Not in use yet). + + compute_linear_spec (bool): + If True data loader computes and returns linear spectrograms alongside the other data. + + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + + use_noise_augment (bool): + Augment the input audio with random noise. + + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + + shuffle (bool): + If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True. + + drop_last (bool): + If True, the data loader will drop the last batch if it is not complete. It helps to prevent + issues that emerge from the partial batch statistics. Defaults to True. + + add_blank (bool): + Add blank characters between each other two characters. It improves performance for some models at expense + of slower run-time due to the longer input sequence. + + datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + + test_sentences (List[str]): + List of sentences to be used at testing. 
Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. + """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # phoneme settings + use_phonemes: bool = False + phonemizer: str = None + phoneme_language: str = None + compute_input_seq_cache: bool = False + text_cleaner: str = None + enable_eos_bos_chars: bool = False + test_sentences_file: str = "" + phoneme_cache_path: str = None + # vocabulary parameters + characters: CharactersConfig = None + add_blank: bool = False + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") + compute_f0: bool = False + compute_energy: bool = False + compute_linear_spec: bool = False + precompute_num_workers: int = 0 + use_noise_augment: bool = False + start_by_longest: bool = False + shuffle: bool = False + drop_last: bool = False + # dataset + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = "radam" + optimizer_params: dict = None + # scheduler + lr_scheduler: str = None + lr_scheduler_params: dict = field(default_factory=lambda: {}) + # testing + test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py new file mode 100644 index 0000000..bf8517d --- /dev/null +++ b/TTS/tts/configs/speedy_speech_config.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class SpeedySpeechConfig(BaseTTSConfig): + """Configure `ForwardTTS` as SpeedySpeech model. 
+ + Example: + + >>> from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig + >>> config = SpeedySpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `speedy_speech`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and uses the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set to True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `NoamLR`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-4`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `l1`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `huber`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to False. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set to 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set to 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set to 0, disables the L1 loss. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary alignment loss. If set to 0, disables the binary loss. Defaults to 0.3. + + binary_loss_warmup_epochs (int): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
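A rough sketch of how the multi-speaker settings above are propagated into `model_args` by the `__post_init__` hook defined below (illustrative values):

    from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig

    config = SpeedySpeechConfig(num_speakers=4, use_speaker_embedding=True)
    # __post_init__ copies the multi-speaker settings into the ForwardTTSArgs instance
    assert config.model_args.num_speakers == 4
    assert config.model_args.use_speaker_embedding is True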
+ """ + + model: str = "speedy_speech" + base_model: str = "forward_tts" + + # set model args as SpeedySpeech + model_args: ForwardTTSArgs = field( + default_factory=lambda: ForwardTTSArgs( + use_pitch=False, + encoder_type="residual_conv_bn", + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13, + }, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + out_channels=80, + hidden_channels=128, + positional_encoding=True, + detach_duration_predictor=True, + ) + ) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "l1" + duration_loss_type: str = "huber" + use_ssim_loss: bool = False + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 0.3 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/tacotron2_config.py b/TTS/tts/configs/tacotron2_config.py new file mode 100644 index 0000000..95b6520 --- /dev/null +++ b/TTS/tts/configs/tacotron2_config.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from TTS.tts.configs.tacotron_config import TacotronConfig + + +@dataclass +class Tacotron2Config(TacotronConfig): + """Defines parameters for Tacotron2 based models. + + Example: + + >>> from TTS.tts.configs.tacotron2_config import Tacotron2Config + >>> config = Tacotron2Config() + + Check `TacotronConfig` for argument descriptions. 
+ """ + + model: str = "tacotron2" + out_channels: int = 80 + encoder_in_features: int = 512 + decoder_in_features: int = 512 diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py new file mode 100644 index 0000000..350b5ea --- /dev/null +++ b/TTS/tts/configs/tacotron_config.py @@ -0,0 +1,235 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig, CapacitronVAEConfig, GSTConfig + + +@dataclass +class TacotronConfig(BaseTTSConfig): + """Defines parameters for Tacotron based models. + + Example: + + >>> from TTS.tts.configs.tacotron_config import TacotronConfig + >>> config = TacotronConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Tacotron`. + use_gst (bool): + enable / disable the use of Global Style Token modules. Defaults to False. + gst (GSTConfig): + Instance of `GSTConfig` class. + gst_style_input (str): + Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and + this is not defined, the model uses a zero vector as an input. Defaults to None. + use_capacitron_vae (bool): + enable / disable the use of Capacitron modules. Defaults to False. + capacitron_vae (CapacitronConfig): + Instance of `CapacitronConfig` class. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + num_speakers (int): + Number of speakers for multi-speaker models. Defaults to 1. + r (int): + Initial number of output frames that the decoder computed per iteration. Larger values makes training and inference + faster but reduces the quality of the output frames. This must be equal to the largest `r` value used in + `gradual_training` schedule. Defaults to 1. + gradual_training (List[List]): + Parameters for the gradual training schedule. It is in the form `[[a, b, c], [d ,e ,f] ..]` where `a` is + the step number to start using the rest of the values, `b` is the `r` value and `c` is the batch size. + If sets None, no gradual training is used. Defaults to None. + memory_size (int): + Defines the number of previous frames used by the Prenet. If set to < 0, then it uses only the last frame. + Defaults to -1. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dropout (bool): + enables / disables the use of dropout in the Prenet. Defaults to True. + prenet_dropout_at_inference (bool): + enable / disable the use of dropout in the Prenet at the inference time. Defaults to False. + stopnet (bool): + enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True. + stopnet_pos_weight (float): + Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with + datasets with longer sentences. Defaults to 0.2. + max_decoder_steps (int): + Max number of steps allowed for the decoder. Defaults to 50. + encoder_in_features (int): + Channels of encoder input and character embedding tensors. Defaults to 256. + decoder_in_features (int): + Channels of decoder input and encoder output tensors. Defaults to 256. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + separate_stopnet (bool): + Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True. 
+ attention_type (str): + attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'. + attention_heads (int): + Number of attention heads for GMM attention. Defaults to 5. + windowing (bool): + It especially useful at inference to keep attention alignment diagonal. Defaults to False. + use_forward_attn (bool): + It is only valid if ```attn_type``` is ```original```. Defaults to False. + forward_attn_mask (bool): + enable/disable extra masking over forward attention. It is useful at inference to prevent + possible attention failures. Defaults to False. + transition_agent (bool): + enable/disable transition agent in forward attention. Defaults to False. + location_attn (bool): + enable/disable location sensitive attention as in the original Tacotron2 paper. + It is only valid if ```attn_type``` is ```original```. Defaults to True. + bidirectional_decoder (bool): + enable/disable bidirectional decoding. Defaults to False. + double_decoder_consistency (bool): + enable/disable double decoder consistency. Defaults to False. + ddc_r (int): + reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this + as a multiple of the `r` value. Defaults to 6. + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to `RAdam`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `NoamLR`. + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + lr (float): + Initial learning rate. Defaults to `1e-4`. + wd (float): + Weight decay coefficient. Defaults to `1e-6`. + grad_clip (float): + Gradient clipping threshold. Defaults to `5`. + seq_len_norm (bool): + enable / disable the sequnce length normalization in the loss functions. If set True, loss of a sample + is divided by the sequence length. Defaults to False. + loss_masking (bool): + enable / disable masking the paddings of the samples in loss computation. Defaults to True. + decoder_loss_alpha (float): + Weight for the decoder loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_loss_alpha (float): + Weight for the postnet loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_diff_spec_alpha (float): + Weight for the postnet differential loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + decoder_diff_spec_alpha (float): + + Weight for the decoder differential loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. 
Defaults to 0.25 + decoder_ssim_alpha (float): + Weight for the decoder SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_ssim_alpha (float): + Weight for the postnet SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + ga_alpha (float): + Weight for the guided attention loss. If set less than or equal to zero, it disables the corresponding loss + function. Defaults to 5. + """ + + model: str = "tacotron" + # model_params: TacotronArgs = field(default_factory=lambda: TacotronArgs()) + use_gst: bool = False + gst: GSTConfig = None + gst_style_input: str = None + + use_capacitron_vae: bool = False + capacitron_vae: CapacitronVAEConfig = None + + # model specific params + num_speakers: int = 1 + num_chars: int = 0 + r: int = 2 + gradual_training: List[List[int]] = None + memory_size: int = -1 + prenet_type: str = "original" + prenet_dropout: bool = True + prenet_dropout_at_inference: bool = False + stopnet: bool = True + separate_stopnet: bool = True + stopnet_pos_weight: float = 0.2 + max_decoder_steps: int = 10000 + encoder_in_features: int = 256 + decoder_in_features: int = 256 + decoder_output_dim: int = 80 + out_channels: int = 513 + + # attention layers + attention_type: str = "original" + attention_heads: int = None + attention_norm: str = "sigmoid" + attention_win: bool = False + windowing: bool = False + use_forward_attn: bool = False + forward_attn_mask: bool = False + transition_agent: bool = False + location_attn: bool = True + + # advance methods + bidirectional_decoder: bool = False + double_decoder_consistency: bool = False + ddc_r: int = 6 + + # multi-speaker settings + speakers_file: str = None + use_speaker_embedding: bool = False + speaker_embedding_dim: int = 512 + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = None + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + seq_len_norm: bool = False + loss_masking: bool = True + + # loss params + decoder_loss_alpha: float = 0.25 + postnet_loss_alpha: float = 0.25 + postnet_diff_spec_alpha: float = 0.25 + decoder_diff_spec_alpha: float = 0.25 + decoder_ssim_alpha: float = 0.25 + postnet_ssim_alpha: float = 0.25 + ga_alpha: float = 5.0 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def check_values(self): + if self.gradual_training: + assert ( + self.gradual_training[0][1] == self.r + ), f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. 
{self.gradual_training[0][1]} vs {self.r}" + if self.model == "tacotron" and self.audio is not None: + assert self.out_channels == ( + self.audio.fft_size // 2 + 1 + ), f"{self.out_channels} vs {self.audio.fft_size // 2 + 1}" + if self.model == "tacotron2" and self.audio is not None: + assert self.out_channels == self.audio.num_mels diff --git a/TTS/tts/configs/tortoise_config.py b/TTS/tts/configs/tortoise_config.py new file mode 100644 index 0000000..d60e43d --- /dev/null +++ b/TTS/tts/configs/tortoise_config.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass, field + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig + + +@dataclass +class TortoiseConfig(BaseTTSConfig): + """Defines parameters for the Tortoise TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (TortoiseArgs): + Model architecture arguments. Defaults to `TortoiseArgs()`. + + audio (TortoiseAudioConfig): + Audio processing configuration. Defaults to `TortoiseAudioConfig()`. + + model_dir (str): + Path to the folder that has all the Tortoise models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values make predictions more creative at the cost of stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), + length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. + + repetition_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`. + + top_p (float): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Defaults to `0.8`. + + cond_free_k (float): + Knob that determines how to balance the conditioning-free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absent_output*cond_free_k. Defaults to `2.0`. + + diffusion_temperature (float): + Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + Defaults to `1.0`. + + num_autoregressive_samples (int): + Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + Defaults to `16`. + + diffusion_iterations (int): + Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. Defaults to `30`. + + sampler (str): + Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`. + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
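A small sketch of the speed/quality trade-off controlled by the sampling parameters above (illustrative values only):

    from TTS.tts.configs.tortoise_config import TortoiseConfig

    # more autoregressive candidates and diffusion steps -> slower, usually better output
    config = TortoiseConfig(num_autoregressive_samples=64, diffusion_iterations=200, sampler="dpm++2m")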
+ + Example: + + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> config = TortoiseConfig() + """ + + model: str = "tortoise" + # model specific params + model_args: TortoiseArgs = field(default_factory=TortoiseArgs) + audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig) + model_dir: str = None + + # settings + temperature: float = 0.2 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_p: float = 0.8 + cond_free_k: float = 2.0 + diffusion_temperature: float = 1.0 + + # inference params + num_autoregressive_samples: int = 16 + diffusion_iterations: int = 30 + sampler: str = "ddim" diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py new file mode 100644 index 0000000..2d0242b --- /dev/null +++ b/TTS/tts/configs/vits_config.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.vits import VitsArgs, VitsAudioConfig + + +@dataclass +class VitsConfig(BaseTTSConfig): + """Defines parameters for VITS End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (VitsArgs): + Model architecture arguments. Defaults to `VitsArgs()`. + + audio (VitsAudioConfig): + Audio processing configuration. Defaults to `VitsAudioConfig()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. + + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. 
For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> config = VitsConfig() + """ + + model: str = "vits" + # model specific params + model_args: VitsArgs = field(default_factory=VitsArgs) + audio: VitsAudioConfig = field(default_factory=VitsAudioConfig) + + # optimizer + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # loss params + kl_loss_alpha: float = 1.0 + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + mel_loss_alpha: float = 45.0 + dur_loss_alpha: float = 1.0 + speaker_encoder_loss_alpha: float = 1.0 + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # testing + test_sentences: List[List] = field( + default_factory=lambda: [ + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. 
It's so delicious and moist."], + ["Prior to November 22, 1963."], + ] + ) + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: List[str] = None + d_vector_dim: int = None + + def __post_init__(self): + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py new file mode 100644 index 0000000..bbf048e --- /dev/null +++ b/TTS/tts/configs/xtts_config.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + + +@dataclass +class XttsConfig(BaseTTSConfig): + """Defines parameters for XTTS TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (XttsArgs): + Model architecture arguments. Defaults to `XttsArgs()`. + + audio (XttsAudioConfig): + Audio processing configuration. Defaults to `XttsAudioConfig()`. + + model_dir (str): + Path to the folder that has all the XTTS models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values makes predictions more creative sacrificing stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), + length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. + + repetition_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`. + + top_p (float): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Defaults to `0.8`. + + num_gpt_outputs (int): + Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As XTTS is a probabilistic model, more samples means a higher probability of creating something "great". + Defaults to `16`. + + gpt_cond_len (int): + Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`. + + gpt_cond_chunk_len (int): + Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the + latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`. + + max_ref_len (int): + Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`. + + sound_norm_refs (bool): + Whether to normalize the conditioning audio. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. 
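A small sketch of the conditioning parameters described above (illustrative values; `gpt_cond_chunk_len` must not exceed `gpt_cond_len`):

    from TTS.tts.configs.xtts_config import XttsConfig

    # latents are extracted over 4-second chunks of a 12-second reference window and averaged
    config = XttsConfig(gpt_cond_len=12, gpt_cond_chunk_len=4, temperature=0.65)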
+ + Example: + + >>> from TTS.tts.configs.xtts_config import XttsConfig + >>> config = XttsConfig() + """ + + model: str = "xtts" + # model specific params + model_args: XttsArgs = field(default_factory=XttsArgs) + audio: XttsAudioConfig = field(default_factory=XttsAudioConfig) + model_dir: str = None + languages: List[str] = field( + default_factory=lambda: [ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh-cn", + "hu", + "ko", + "ja", + "hi", + ] + ) + + # inference params + temperature: float = 0.85 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_k: int = 50 + top_p: float = 0.85 + num_gpt_outputs: int = 1 + + # cloning + gpt_cond_len: int = 12 + gpt_cond_chunk_len: int = 4 + max_ref_len: int = 10 + sound_norm_refs: bool = False diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py new file mode 100644 index 0000000..1921385 --- /dev/null +++ b/TTS/tts/datasets/__init__.py @@ -0,0 +1,181 @@ +import os +import sys +from collections import Counter +from pathlib import Path +from typing import Callable, Dict, List, Tuple, Union + +import numpy as np + +from TTS.tts.datasets.dataset import * +from TTS.tts.datasets.formatters import * + + +def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): + """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. + + Args: + items (List[List]): + A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + """ + speakers = [item["speaker_name"] for item in items] + is_multi_speaker = len(set(speakers)) > 1 + if eval_split_size > 1: + eval_split_size = int(eval_split_size) + else: + if eval_split_max_size: + eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size)) + else: + eval_split_size = int(len(items) * eval_split_size) + + assert ( + eval_split_size > 0 + ), " [!] You do not have enough samples for the evaluation set. 
You can work around this by setting the 'eval_split_size' parameter to a minimum of {}".format( + 1 / len(items) + ) + np.random.seed(0) + np.random.shuffle(items) + if is_multi_speaker: + items_eval = [] + speakers = [item["speaker_name"] for item in items] + speaker_counter = Counter(speakers) + while len(items_eval) < eval_split_size: + item_idx = np.random.randint(0, len(items)) + speaker_to_be_removed = items[item_idx]["speaker_name"] + if speaker_counter[speaker_to_be_removed] > 1: + items_eval.append(items[item_idx]) + speaker_counter[speaker_to_be_removed] -= 1 + del items[item_idx] + return items_eval, items + return items[:eval_split_size], items[eval_split_size:] + + +def add_extra_keys(metadata, language, dataset_name): + for item in metadata: + # add language name + item["language"] = language + # add unique audio name + relfilepath = os.path.splitext(os.path.relpath(item["audio_file"], item["root_path"]))[0] + audio_unique_name = f"{dataset_name}#{relfilepath}" + item["audio_unique_name"] = audio_unique_name + return metadata + + +def load_tts_samples( + datasets: Union[List[Dict], Dict], + eval_split=True, + formatter: Callable = None, + eval_split_max_size=None, + eval_split_size=0.01, +) -> Tuple[List[List], List[List]]: + """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. + If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based + on the dataset name. + + Args: + datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are + in the list, they are all merged. + + eval_split (bool, optional): If true, create an evaluation split. If an eval split is not provided explicitly + via `meta_file_val`, it is generated automatically. Defaults to True. + + formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It + must take the root_path and the meta_file name and return a list of samples in the format of + `[[text, audio_path, speaker_id], ...]`. See the available formatters in `TTS.tts.datasets.formatters` as + examples. Defaults to None. + + eval_split_max_size (int): + Maximum number of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + Returns: + Tuple[List[List], List[List]]: training and evaluation splits of the dataset. + """ + meta_data_train_all = [] + meta_data_eval_all = [] if eval_split else None + if not isinstance(datasets, list): + datasets = [datasets] + for dataset in datasets: + formatter_name = dataset["formatter"] + dataset_name = dataset["dataset_name"] + root_path = dataset["path"] + meta_file_train = dataset["meta_file_train"] + meta_file_val = dataset["meta_file_val"] + ignored_speakers = dataset["ignored_speakers"] + language = dataset["language"] + + # setup the right data processor + if formatter is None: + formatter = _get_formatter_by_name(formatter_name) + # load train set + meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) + assert len(meta_data_train) > 0, f" [!]
No training samples found in {root_path}/{meta_file_train}" + + meta_data_train = add_extra_keys(meta_data_train, language, dataset_name) + + print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") + # load evaluation split if set + if eval_split: + if meta_file_val: + meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) + meta_data_eval = add_extra_keys(meta_data_eval, language, dataset_name) + else: + eval_size_per_dataset = eval_split_max_size // len(datasets) if eval_split_max_size else None + meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_size_per_dataset, eval_split_size) + meta_data_eval_all += meta_data_eval + meta_data_train_all += meta_data_train + # load attention masks for the duration predictor training + if dataset.meta_file_attn_mask: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) + for idx, ins in enumerate(meta_data_train_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_train_all[idx].update({"alignment_file": attn_file}) + if meta_data_eval_all: + for idx, ins in enumerate(meta_data_eval_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_eval_all[idx].update({"alignment_file": attn_file}) + # set none for the next iter + formatter = None + return meta_data_train_all, meta_data_eval_all + + +def load_attention_mask_meta_data(metafile_path): + """Load meta data file created by compute_attention_masks.py""" + with open(metafile_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + meta_data = [] + for line in lines: + wav_file, attn_file = line.split("|") + meta_data.append([wav_file, attn_file]) + return meta_data + + +def _get_formatter_by_name(name): + """Returns the respective preprocessing function.""" + thismodule = sys.modules[__name__] + return getattr(thismodule, name.lower()) + + +def find_unique_chars(data_samples, verbose=True): + texts = "".join(item[0] for item in data_samples) + chars = set(texts) + lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + + if verbose: + print(f" > Number of unique characters: {len(chars)}") + print(f" > Unique characters: {''.join(sorted(chars))}") + print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") + print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + return chars_force_lower diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py new file mode 100644 index 0000000..19fb25b --- /dev/null +++ b/TTS/tts/datasets/dataset.py @@ -0,0 +1,973 @@ +import base64 +import collections +import os +import random +from typing import Dict, List, Union + +import numpy as np +import torch +import tqdm +from torch.utils.data import Dataset + +from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy + +import mutagen + +# to prevent too many open files error as suggested here +# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +torch.multiprocessing.set_sharing_strategy("file_system") + + +def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, 
speaker_name = item + else: + raise ValueError(" [!] Dataset cannot parse the sample.") + return text, wav_file, speaker_name, language_name, attn_file + + +def noise_augment_audio(wav): + return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + + +def string2filename(string): + # generate a safe and reversible filename based on a string + filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") + return filename + + +def get_audio_size(audiopath): + extension = audiopath.rpartition(".")[-1].lower() + if extension not in {"mp3", "wav", "flac"}: + raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!") + + audio_info = mutagen.File(audiopath).info + return int(audio_info.length * audio_info.sample_rate) + + +class TTSDataset(Dataset): + def __init__( + self, + outputs_per_step: int = 1, + compute_linear_spec: bool = False, + ap: AudioProcessor = None, + samples: List[Dict] = None, + tokenizer: "TTSTokenizer" = None, + compute_f0: bool = False, + compute_energy: bool = False, + f0_cache_path: str = None, + energy_cache_path: str = None, + return_wav: bool = False, + batch_group_size: int = 0, + min_text_len: int = 0, + max_text_len: int = float("inf"), + min_audio_len: int = 0, + max_audio_len: int = float("inf"), + phoneme_cache_path: str = None, + precompute_num_workers: int = 0, + speaker_id_mapping: Dict = None, + d_vector_mapping: Dict = None, + language_id_mapping: Dict = None, + use_noise_augment: bool = False, + start_by_longest: bool = False, + verbose: bool = False, + ): + """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. + + If you need something different, you can subclass and override. + + Args: + outputs_per_step (int): Number of time frames predicted per step. + + compute_linear_spec (bool): compute linear spectrogram if True. + + ap (TTS.tts.utils.AudioProcessor): Audio processor object. + + samples (list): List of dataset samples. + + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else + use the given. Defaults to None. + + compute_f0 (bool): compute f0 if True. Defaults to False. + + compute_energy (bool): compute energy if True. Defaults to False. + + f0_cache_path (str): Path to store f0 cache. Defaults to None. + + energy_cache_path (str): Path to store energy cache. Defaults to None. + + return_wav (bool): Return the waveform of the sample. Defaults to False. + + batch_group_size (int): Range of batch randomization after sorting + sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a + batch. Set 0 to disable. Defaults to 0. + + min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. + Defaults to 0. + + max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. + Defaults to float("inf"). + + min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. + Defaults to 0. + + max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. + The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to + this value if you encounter an OOM error in training. Defaults to float("inf"). + + phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a + separate file. Defaults to None. 
+ + precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. + + speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the + embedding layer. Defaults to None. + + d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None. + + use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. + + start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + + verbose (bool): Print diagnostic information. Defaults to false. + """ + super().__init__() + self.batch_group_size = batch_group_size + self._samples = samples + self.outputs_per_step = outputs_per_step + self.compute_linear_spec = compute_linear_spec + self.return_wav = return_wav + self.compute_f0 = compute_f0 + self.compute_energy = compute_energy + self.f0_cache_path = f0_cache_path + self.energy_cache_path = energy_cache_path + self.min_audio_len = min_audio_len + self.max_audio_len = max_audio_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len + self.ap = ap + self.phoneme_cache_path = phoneme_cache_path + self.speaker_id_mapping = speaker_id_mapping + self.d_vector_mapping = d_vector_mapping + self.language_id_mapping = language_id_mapping + self.use_noise_augment = use_noise_augment + self.start_by_longest = start_by_longest + + self.verbose = verbose + self.rescue_item_idx = 1 + self.pitch_computed = False + self.tokenizer = tokenizer + + if self.tokenizer.use_phonemes: + self.phoneme_dataset = PhonemeDataset( + self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + ) + + if compute_f0: + self.f0_dataset = F0Dataset( + self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + ) + if compute_energy: + self.energy_dataset = EnergyDataset( + self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers + ) + if self.verbose: + self.print_logs() + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = get_audio_size(wav_file) + lens.append(audio_len) + return lens + + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, new_samples): + self._samples = new_samples + if hasattr(self, "f0_dataset"): + self.f0_dataset.samples = new_samples + if hasattr(self, "energy_dataset"): + self.energy_dataset.samples = new_samples + if hasattr(self, "phoneme_dataset"): + self.phoneme_dataset.samples = new_samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.load_data(idx) + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> DataLoader initialization") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + def load_wav(self, filename): + waveform = self.ap.load_wav(filename) + assert waveform.size > 0 + return waveform + + def get_phonemes(self, idx, text): + out_dict = self.phoneme_dataset[idx] + assert text == out_dict["text"], f"{text} != {out_dict['text']}" + assert len(out_dict["token_ids"]) > 0 + return out_dict + + def get_f0(self, idx): + out_dict = self.f0_dataset[idx] + item = self.samples[idx] + assert item["audio_unique_name"] == out_dict["audio_unique_name"] + return out_dict + + def get_energy(self, idx): + 
out_dict = self.energy_dataset[idx] + item = self.samples[idx] + assert item["audio_unique_name"] == out_dict["audio_unique_name"] + return out_dict + + @staticmethod + def get_attn_mask(attn_file): + return np.load(attn_file) + + def get_token_ids(self, idx, text): + if self.tokenizer.use_phonemes: + token_ids = self.get_phonemes(idx, text)["token_ids"] + else: + token_ids = self.tokenizer.text_to_ids(text) + return np.array(token_ids, dtype=np.int32) + + def load_data(self, idx): + item = self.samples[idx] + + raw_text = item["text"] + + wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) + + # apply noise for augmentation + if self.use_noise_augment: + wav = noise_augment_audio(wav) + + # get token ids + token_ids = self.get_token_ids(idx, item["text"]) + + # get pre-computed attention maps + attn = None + if "alignment_file" in item: + attn = self.get_attn_mask(item["alignment_file"]) + + # after phonemization the text length may change + # this is a shareful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: + self.rescue_item_idx += 1 + return self.load_data(self.rescue_item_idx) + + # get f0 values + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + energy = None + if self.compute_energy: + energy = self.get_energy(idx)["energy"] + + sample = { + "raw_text": raw_text, + "token_ids": token_ids, + "wav": wav, + "pitch": f0, + "energy": energy, + "attn": attn, + "item_idx": item["audio_file"], + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "wav_file_name": os.path.basename(item["audio_file"]), + "audio_unique_name": item["audio_unique_name"], + } + return sample + + @staticmethod + def _compute_lengths(samples): + new_samples = [] + for item in samples: + audio_length = get_audio_size(item["audio_file"]) + text_lenght = len(item["text"]) + item["audio_length"] = audio_length + item["text_length"] = text_lenght + new_samples += [item] + return new_samples + + @staticmethod + def filter_by_length(lengths: List[int], min_len: int, max_len: int): + idxs = np.argsort(lengths) # ascending order + ignore_idx = [] + keep_idx = [] + for idx in idxs: + length = lengths[idx] + if length < min_len or length > max_len: + ignore_idx.append(idx) + else: + keep_idx.append(idx) + return ignore_idx, keep_idx + + @staticmethod + def sort_by_length(samples: List[List]): + audio_lengths = [s["audio_length"] for s in samples] + idxs = np.argsort(audio_lengths) # ascending order + return idxs + + @staticmethod + def create_buckets(samples, batch_group_size: int): + assert batch_group_size > 0 + for i in range(len(samples) // batch_group_size): + offset = i * batch_group_size + end_offset = offset + batch_group_size + temp_items = samples[offset:end_offset] + random.shuffle(temp_items) + samples[offset:end_offset] = temp_items + return samples + + @staticmethod + def _select_samples_by_idx(idxs, samples): + samples_new = [] + for idx in idxs: + samples_new.append(samples[idx]) + return samples_new + + def preprocess_samples(self): + r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length + range. 
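+        Also moves the longest sample to the front when `start_by_longest` is enabled and shuffles
+        samples within groups of `batch_group_size` so that similarly sized items land in the same batch.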
+ """ + samples = self._compute_lengths(self.samples) + + # sort items based on the sequence length in ascending order + text_lengths = [i["text_length"] for i in samples] + audio_lengths = [i["audio_length"] for i in samples] + text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) + audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) + keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) + ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) + + samples = self._select_samples_by_idx(keep_idx, samples) + + sorted_idxs = self.sort_by_length(samples) + + if self.start_by_longest: + longest_idxs = sorted_idxs[-1] + sorted_idxs[-1] = sorted_idxs[0] + sorted_idxs[0] = longest_idxs + + samples = self._select_samples_by_idx(sorted_idxs, samples) + + if len(samples) == 0: + raise RuntimeError(" [!] No samples left") + + # shuffle batch groups + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. + if self.batch_group_size > 0: + samples = self.create_buckets(samples, self.batch_group_size) + + # update items to the new sorted items + audio_lengths = [s["audio_length"] for s in samples] + text_lengths = [s["text_length"] for s in samples] + self.samples = samples + + if self.verbose: + print(" | > Preprocessing samples") + print(" | > Max text length: {}".format(np.max(text_lengths))) + print(" | > Min text length: {}".format(np.min(text_lengths))) + print(" | > Avg text length: {}".format(np.mean(text_lengths))) + print(" | ") + print(" | > Max audio length: {}".format(np.max(audio_lengths))) + print(" | > Min audio length: {}".format(np.min(audio_lengths))) + print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) + print(f" | > Num. instances discarded samples: {len(ignore_idx)}") + print(" | > Batch group size: {}.".format(self.batch_group_size)) + + @staticmethod + def _sort_batch(batch, text_lengths): + """Sort the batch by the input text length for RNN efficiency. + + Args: + batch (Dict): Batch returned by `__getitem__`. + text_lengths (List[int]): Lengths of the input character sequences. + """ + text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) + batch = [batch[idx] for idx in ids_sorted_decreasing] + return batch, text_lengths, ids_sorted_decreasing + + def collate_fn(self, batch): + r""" + Perform preprocessing and create a final data batch: + 1. Sort batch instances by text-length + 2. Convert Audio signal to features. + 3. PAD sequences wrt r. + 4. Load to Torch. 
+ """ + + # Puts each data field into a tensor with outer dimension batch size + if isinstance(batch[0], collections.abc.Mapping): + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) + + # sort items with text input length for RNN efficiency + batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) + + # convert list of dicts to dict of lists + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + # get language ids from language names + if self.language_id_mapping is not None: + language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] + else: + language_ids = None + # get pre-computed d-vectors + if self.d_vector_mapping is not None: + embedding_keys = list(batch["audio_unique_name"]) + d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys] + else: + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_id_mapping: + speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] + else: + speaker_ids = None + # compute features + mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] + + mel_lengths = [m.shape[1] for m in mel] + + # lengths adjusted by the reduction factor + mel_lengths_adjusted = [ + m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) + if m.shape[1] % self.outputs_per_step + else m.shape[1] + for m in mel + ] + + # compute 'stop token' targets + stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] + + # PAD stop targets + stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) + + # PAD sequences with longest instance in the batch + token_ids = prepare_data(batch["token_ids"]).astype(np.int32) + + # PAD features with longest instance + mel = prepare_tensor(mel, self.outputs_per_step) + + # B x D x T --> B x T x D + mel = mel.transpose(0, 2, 1) + + # convert things to pytorch + token_ids_lengths = torch.LongTensor(token_ids_lengths) + token_ids = torch.LongTensor(token_ids) + mel = torch.FloatTensor(mel).contiguous() + mel_lengths = torch.LongTensor(mel_lengths) + stop_targets = torch.FloatTensor(stop_targets) + + # speaker vectors + if d_vectors is not None: + d_vectors = torch.FloatTensor(d_vectors) + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + # compute linear spectrogram + linear = None + if self.compute_linear_spec: + linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] + linear = prepare_tensor(linear, self.outputs_per_step) + linear = linear.transpose(0, 2, 1) + assert mel.shape[1] == linear.shape[1] + linear = torch.FloatTensor(linear).contiguous() + + # format waveforms + wav_padded = None + if self.return_wav: + wav_lengths = [w.shape[0] for w in batch["wav"]] + max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length + wav_lengths = torch.LongTensor(wav_lengths) + wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) + for i, w in enumerate(batch["wav"]): + mel_length = mel_lengths_adjusted[i] + w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") + w = w[: mel_length * self.ap.hop_length] + wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) + wav_padded.transpose_(1, 2) + + # format F0 + if self.compute_f0: + pitch = prepare_data(batch["pitch"]) + assert mel.shape[1] == pitch.shape[1], f"[!] 
{mel.shape} vs {pitch.shape}" + pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT + else: + pitch = None + # format energy + if self.compute_energy: + energy = prepare_data(batch["energy"]) + assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}" + energy = torch.FloatTensor(energy)[:, None, :].contiguous() # B x 1 xT + else: + energy = None + # format attention masks + attns = None + if batch["attn"][0] is not None: + attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] + for idx, attn in enumerate(attns): + pad2 = mel.shape[1] - attn.shape[1] + pad1 = token_ids.shape[1] - attn.shape[0] + assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" + attn = np.pad(attn, [[0, pad1], [0, pad2]]) + attns[idx] = attn + attns = prepare_tensor(attns, self.outputs_per_step) + attns = torch.FloatTensor(attns).unsqueeze(1) + + return { + "token_id": token_ids, + "token_id_lengths": token_ids_lengths, + "speaker_names": batch["speaker_name"], + "linear": linear, + "mel": mel, + "mel_lengths": mel_lengths, + "stop_targets": stop_targets, + "item_idxs": batch["item_idx"], + "d_vectors": d_vectors, + "speaker_ids": speaker_ids, + "attns": attns, + "waveform": wav_padded, + "raw_text": batch["raw_text"], + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_name"], + } + + raise TypeError( + ( + "batch must contain tensors, numbers, dicts or lists;\ + found {}".format( + type(batch[0]) + ) + ) + ) + + +class PhonemeDataset(Dataset): + """Phoneme Dataset for converting input text to phonemes and then token IDs + + At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data + loading latency. If `cache_path` is already present, it skips the pre-computation. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + tokenizer (TTSTokenizer): + Tokenizer to convert input text to phonemes. + + cache_path (str): + Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation. + + precompute_num_workers (int): + Number of workers used for pre-computing the phonemes. Defaults to 0. + """ + + def __init__( + self, + samples: Union[List[Dict], List[List]], + tokenizer: "TTSTokenizer", + cache_path: str, + precompute_num_workers=0, + ): + self.samples = samples + self.tokenizer = tokenizer + self.cache_path = cache_path + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + + def __getitem__(self, index): + item = self.samples[index] + ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"]) + ph_hat = self.tokenizer.ids_to_text(ids) + return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} + + def __len__(self): + return len(self.samples) + + def compute_or_load(self, file_name, text, language): + """Compute phonemes for the given text. + + If the phonemes are already cached, load them from cache. 
+ """ + file_ext = "_phoneme.npy" + cache_path = os.path.join(self.cache_path, file_name + file_ext) + try: + ids = np.load(cache_path) + except FileNotFoundError: + ids = self.tokenizer.text_to_ids(text, language=language) + np.save(cache_path, ids) + return ids + + def get_pad_id(self): + """Get pad token ID for sequence padding""" + return self.tokenizer.pad_id + + def precompute(self, num_workers=1): + """Precompute phonemes for all samples. + + We use pytorch dataloader because we are lazy. + """ + print("[*] Pre-computing phonemes...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + for _ in dataloder: + pbar.update(batch_size) + + def collate_fn(self, batch): + ids = [item["token_ids"] for item in batch] + ids_lens = [item["token_ids_len"] for item in batch] + texts = [item["text"] for item in batch] + texts_hat = [item["ph_hat"] for item in batch] + ids_lens_max = max(ids_lens) + ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id()) + for i, ids_len in enumerate(ids_lens): + ids_torch[i, :ids_len] = torch.LongTensor(ids[i]) + return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> PhonemeDataset ") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class F0Dataset: + """F0 Dataset for computing F0 from wav files in CPU + + Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It + also computes the mean and std of F0 values if `normalize_f0` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute F0 from wav files. + + cache_path (str): + Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the F0 values. Defaults to 0. + + normalize_f0 (bool): + Whether to normalize F0 values by mean and std. Defaults to True. + """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + audio_config=None, # pylint: disable=unused-argument + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + self.samples = samples + self.ap = ap + self.verbose = verbose + self.cache_path = cache_path + self.normalize_f0 = normalize_f0 + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_f0: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"])) + if self.normalize_f0: + assert self.mean is not None and self.std is not None, " [!] 
Mean and STD is not available" + f0 = self.normalize(f0) + return {"audio_unique_name": item["audio_unique_name"], "f0": f0} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + print("[*] Pre-computing F0s...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_f0 = self.normalize_f0 + self.normalize_f0 = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloder: + f0 = batch["f0"] + computed_data.append(f for f in f0) + pbar.update(batch_size) + self.normalize_f0 = normalize_f0 + + if self.normalize_f0: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + pitch_mean, pitch_std = self.compute_pitch_stats(computed_data) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_pitch_file_path(file_name, cache_path): + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch = pitch - self.mean + pitch = pitch / self.std + pitch[zero_idxs] = 0.0 + return pitch + + def denormalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch *= self.std + pitch += self.mean + pitch[zero_idxs] = 0.0 + return pitch + + def compute_or_load(self, wav_file, audio_unique_name): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + def collate_fn(self, batch): + audio_unique_name = [item["audio_unique_name"] for item in batch] + f0s = [item["f0"] for item in batch] + f0_lens = [len(item["f0"]) for item in batch] + f0_lens_max = max(f0_lens) + f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) + for i, f0_len in enumerate(f0_lens): + f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i]) + return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> F0Dataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class EnergyDataset: + """Energy Dataset for computing Energy from wav files in CPU + + Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. 
It + also computes the mean and std of Energy values if `normalize_Energy` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute Energy from wav files. + + cache_path (str): + Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the Energy values. Defaults to 0. + + normalize_Energy (bool): + Whether to normalize Energy values by mean and std. Defaults to True. + """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_energy=True, + ): + self.samples = samples + self.ap = ap + self.verbose = verbose + self.cache_path = cache_path + self.normalize_energy = normalize_energy + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_energy: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"])) + if self.normalize_energy: + assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" + energy = self.normalize(energy) + return {"audio_unique_name": item["audio_unique_name"], "energy": energy} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + print("[*] Pre-computing energys...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_energy = self.normalize_energy + self.normalize_energy = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloder: + energy = batch["energy"] + computed_data.append(e for e in energy) + pbar.update(batch_size) + self.normalize_energy = normalize_energy + + if self.normalize_energy: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + energy_mean, energy_std = self.compute_energy_stats(computed_data) + energy_stats = {"mean": energy_mean, "std": energy_std} + np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_energy_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + energy_file = os.path.join(cache_path, file_name + "_energy.npy") + return energy_file + + @staticmethod + def _compute_and_save_energy(ap, wav_file, energy_file=None): + wav = ap.load_wav(wav_file) + energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length) + if energy_file: + np.save(energy_file, energy) + return energy + + @staticmethod + def compute_energy_stats(energy_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "energy_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = 
stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy = energy - self.mean + energy = energy / self.std + energy[zero_idxs] = 0.0 + return energy + + def denormalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy *= self.std + energy += self.mean + energy[zero_idxs] = 0.0 + return energy + + def compute_or_load(self, wav_file, audio_unique_name): + """ + compute energy and return a numpy array of energy values + """ + energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path) + if not os.path.exists(energy_file): + energy = self._compute_and_save_energy(self.ap, wav_file, energy_file) + else: + energy = np.load(energy_file) + return energy.astype(np.float32) + + def collate_fn(self, batch): + audio_unique_name = [item["audio_unique_name"] for item in batch] + energys = [item["energy"] for item in batch] + energy_lens = [len(item["energy"]) for item in batch] + energy_lens_max = max(energy_lens) + energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id()) + for i, energy_len in enumerate(energy_lens): + energys_torch[i, :energy_len] = torch.LongTensor(energys[i]) + return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> energyDataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py new file mode 100644 index 0000000..053444b --- /dev/null +++ b/TTS/tts/datasets/formatters.py @@ -0,0 +1,655 @@ +import os +import re +import xml.etree.ElementTree as ET +from glob import glob +from pathlib import Path +from typing import List + +import pandas as pd +from tqdm import tqdm + +######################## +# DATASETS +######################## + + +def cml_tts(root_path, meta_file, ignored_speakers=None): + """Normalizes the CML-TTS meta data file to TTS format + https://github.com/freds0/CML-TTS-Dataset/""" + filepath = os.path.join(root_path, meta_file) + # ensure there are 4 columns for every line + with open(filepath, "r", encoding="utf8") as f: + lines = f.readlines() + num_cols = len(lines[0].split("|")) # take the first row as reference + for idx, line in enumerate(lines[1:]): + if len(line.split("|")) != num_cols: + print(f" > Missing column in line {idx + 1} -> {line.strip()}") + # load metadata + metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|") + assert all(x in metadata.columns for x in ["wav_filename", "transcript"]) + client_id = None if "client_id" in metadata.columns else "default" + emotion_name = None if "emotion_name" in metadata.columns else "neutral" + items = [] + not_found_counter = 0 + for row in metadata.itertuples(): + if client_id is None and ignored_speakers is not None and row.client_id in ignored_speakers: + continue + audio_path = os.path.join(root_path, row.wav_filename) + if not os.path.exists(audio_path): + not_found_counter += 1 + continue + items.append( + { + "text": row.transcript, + "audio_file": audio_path, + "speaker_name": client_id if client_id is not None else row.client_id, + "emotion_name": emotion_name if emotion_name is not None else row.emotion_name, + "root_path": root_path, + } + ) + if not_found_counter > 0: + print(f" | > [!] 
{not_found_counter} files not found") + return items + + +def coqui(root_path, meta_file, ignored_speakers=None): + """Interal dataset formatter.""" + filepath = os.path.join(root_path, meta_file) + # ensure there are 4 columns for every line + with open(filepath, "r", encoding="utf8") as f: + lines = f.readlines() + num_cols = len(lines[0].split("|")) # take the first row as reference + for idx, line in enumerate(lines[1:]): + if len(line.split("|")) != num_cols: + print(f" > Missing column in line {idx + 1} -> {line.strip()}") + # load metadata + metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|") + assert all(x in metadata.columns for x in ["audio_file", "text"]) + speaker_name = None if "speaker_name" in metadata.columns else "coqui" + emotion_name = None if "emotion_name" in metadata.columns else "neutral" + items = [] + not_found_counter = 0 + for row in metadata.itertuples(): + if speaker_name is None and ignored_speakers is not None and row.speaker_name in ignored_speakers: + continue + audio_path = os.path.join(root_path, row.audio_file) + if not os.path.exists(audio_path): + not_found_counter += 1 + continue + items.append( + { + "text": row.text, + "audio_file": audio_path, + "speaker_name": speaker_name if speaker_name is not None else row.speaker_name, + "emotion_name": emotion_name if emotion_name is not None else row.emotion_name, + "root_path": root_path, + } + ) + if not_found_counter > 0: + print(f" | > [!] {not_found_counter} files not found") + return items + + +def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalize TWEB dataset. + https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "tweb" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + wav_file = os.path.join(root_path, cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = cols[1].strip() + text = cols[0].strip() + wav_file = os.path.join(root_path, "wavs", wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="ISO 8859-1") as ttf: + for line in ttf: + cols = line.strip().split("|") + wav_file = cols[0].strip() + text = cols[1].strip() + folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" + wav_file = os.path.join(root_path, folder_name, wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mailabs(root_path, meta_files=None, ignored_speakers=None): + """Normalizes M-AI-Labs meta data files to TTS format + + Args: + root_path (str): root folder of the MAILAB language folder. + meta_files (str): list of meta files to be used in the training. 
If None, finds all the csv files + recursively. Defaults to None + """ + speaker_regex = re.compile(f"by_book{os.sep}(male|female){os.sep}(?P[^{os.sep}]+){os.sep}") + if not meta_files: + csv_files = glob(root_path + f"{os.sep}**{os.sep}metadata.csv", recursive=True) + else: + csv_files = meta_files + + # meta_files = [f.strip() for f in meta_files.split(",")] + items = [] + for csv_file in csv_files: + if os.path.isfile(csv_file): + txt_file = csv_file + else: + txt_file = os.path.join(root_path, csv_file) + + folder = os.path.dirname(txt_file) + # determine speaker based on folder structure... + speaker_name_match = speaker_regex.search(txt_file) + if speaker_name_match is None: + continue + speaker_name = speaker_name_match.group("speaker_name") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + print(" | > {}".format(csv_file)) + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + if not meta_files: + wav_file = os.path.join(folder, "wavs", cols[0] + ".wav") + else: + wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") + if os.path.isfile(wav_file): + text = cols[1].strip() + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path} + ) + else: + # M-AI-Labs have some missing samples, so just print the warning + print("> File %s does not exist!" % (wav_file)) + return items + + +def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file to TTS format + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ljspeech" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file for TTS testing + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + speaker_id = 0 + for idx, line in enumerate(ttf): + # 2 samples per speaker to avoid eval split issues + if idx % 2 == 0: + speaker_id += 1 + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path} + ) + return items + + +def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the thorsten meta data file to TTS format + https://github.com/thorstenMueller/deep-learning-german-tts/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "thorsten" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the sam-accenture meta data file to TTS format + 
https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files""" + xml_file = os.path.join(root_path, "voice_over_recordings", meta_file) + xml_root = ET.parse(xml_file).getroot() + items = [] + speaker_name = "sam_accenture" + for item in xml_root.findall("./fileid"): + text = item.text + wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") + if not os.path.exists(wav_file): + print(f" [!] {wav_file} in metafile does not exist. Skipping...") + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the RUSLAN meta data file to TTS format + https://ruslan-corpus.github.io/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ruslan" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the CSS10 dataset file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "css10" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the Nancy meta data file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "nancy" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + utt_id = line.split()[1] + text = line[line.find('"') + 1 : line.rfind('"') - 1] + wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def common_voice(root_path, meta_file, ignored_speakers=None): + """Normalize the common voice meta data file to TTS format.""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("client_id"): + continue + cols = line.split("\t") + text = cols[2] + speaker_name = cols[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path} + ) + return items + + +def libri_tts(root_path, meta_files=None, ignored_speakers=None): + """https://ai.google/tools/datasets/libri-tts/""" + items = [] + if not meta_files: + meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True) + else: + if isinstance(meta_files, str): + meta_files = [os.path.join(root_path, meta_files)] + + for meta_file in meta_files: + _meta_file = os.path.basename(meta_file).split(".")[0] + with open(meta_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + file_name = cols[0] + speaker_name, chapter_id, *_ = 
cols[0].split("_") + _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}") + wav_file = os.path.join(_root_path, file_name + ".wav") + text = cols[2] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + items.append( + { + "text": text, + "audio_file": wav_file, + "speaker_name": f"LTTS_{speaker_name}", + "root_path": root_path, + } + ) + for item in items: + assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}" + return items + + +def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "turkish-female" + skipped_files = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav") + if not os.path.exists(wav_file): + skipped_files.append(wav_file) + continue + text = cols[1].strip() + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + print(f" [!] {len(skipped_files)} files skipped. They don't exist...") + return items + + +# ToDo: add the dataset link when the dataset is released publicly +def brspeech(root_path, meta_file, ignored_speakers=None): + """BRSpeech 3.0 beta""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("wav_filename"): + continue + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] + speaker_id = cols[3] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path}) + return items + + +def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None): + """VCTK dataset v0.92. + + URL: + https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip + + This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```. + It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset. + + mic1: + Audio recorded using an omni-directional microphone (DPA 4035). + Contains very low frequency noises. + This is the same audio released in previous versions of VCTK: + https://doi.org/10.7488/ds/1994 + + mic2: + Audio recorded using a small diaphragm condenser microphone with + very wide bandwidth (Sennheiser MKH 800). + Two speakers, p280 and p315 had technical issues of the audio + recordings using MKH 800. 
+ """ + file_ext = "flac" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + # p280 has no mic2 recordings + if speaker_id == "p280": + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}") + else: + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") + if os.path.exists(wav_file): + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} + ) + else: + print(f" [!] wav files don't exist - {wav_file}") + return items + + +def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): + """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path} + ) + return items + + +def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-argument + items = [] + speaker_name = "synpaflex" + root_path = os.path.join(root_path, "") + wav_files = glob(f"{root_path}**/*.wav", recursive=True) + for wav_file in wav_files: + if os.sep + "wav" + os.sep in wav_file: + txt_file = wav_file.replace("wav", "txt") + else: + txt_file = os.path.join( + os.path.dirname(wav_file), "txt", os.path.basename(wav_file).replace(".wav", ".txt") + ) + if os.path.exists(txt_file) and os.path.exists(wav_file): + with open(txt_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, ignored_speakers=None): + """ToDo: Refer the paper when available""" + items = [] + split_dir = meta_files + meta_files = glob(f"{os.path.join(root_path, split_dir)}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readline().replace("\n", "") + # ignore sentences that contains digits + if ignore_digits_sentences and any(map(str.isdigit, text)): + continue + wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac") + items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path}) + return items + + +def mls(root_path, meta_files=None, ignored_speakers=None): + 
"""http://www.openslr.org/94/""" + items = [] + with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta: + for line in meta: + file, text = line.split("\t") + text = text[:-1] + speaker, book, *_ = file.split("_") + wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker in ignored_speakers: + continue + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path} + ) + return items + + +# ======================================== VOX CELEB =========================================== +def voxceleb2(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + :param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="2") + + +def voxceleb1(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + :param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="1") + + +def _voxcel_x(root_path, meta_file, voxcel_idx): + assert voxcel_idx in ["1", "2"] + expected_count = 148_000 if voxcel_idx == "1" else 1_000_000 + voxceleb_path = Path(root_path) + cache_to = voxceleb_path / f"metafile_voxceleb{voxcel_idx}.csv" + cache_to.parent.mkdir(exist_ok=True) + + # if not exists meta file, crawl recursively for 'wav' files + if meta_file is not None: + with open(str(meta_file), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + elif not cache_to.exists(): + cnt = 0 + meta_data = [] + wav_files = voxceleb_path.rglob("**/*.wav") + for path in tqdm( + wav_files, + desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", + total=expected_count, + ): + speaker_id = str(Path(path).parent.parent.stem) + assert speaker_id.startswith("id") + text = None # VoxCel does not provide transciptions, and they are not needed for training the SE + meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") + cnt += 1 + with open(str(cache_to), "w", encoding="utf-8") as f: + f.write("".join(meta_data)) + if cnt < expected_count: + raise ValueError(f"Found too few instances for Voxceleb. 
Should be around {expected_count}, is: {cnt}") + + with open(str(cache_to), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + +def emotion(root_path, meta_file, ignored_speakers=None): + """Generic emotion dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("file_path"): + continue + cols = line.split(",") + wav_file = os.path.join(root_path, cols[0]) + speaker_id = cols[1] + emotion_id = cols[2].replace("\n", "") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append( + {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path} + ) + return items + + +def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument + """Normalizes the Baker meta data file to TTS format + + Args: + root_path (str): path to the baker dataset + meta_file (str): name of the meta dataset containing names of wav to select and the transcript of the sentence + Returns: + List[List[str]]: List of (text, wav_path, speaker_name) associated with each sentences + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "baker" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + wav_name, text = line.rstrip("\n").split("|") + wav_path = os.path.join(root_path, "clips_22", wav_name) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kokoro" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2].replace(" ", "") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def kss(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Korean single-speaker dataset from https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kss" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] # cols[1] => 6월, cols[2] => 유월 + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def bel_tts_formatter(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "bel_tts" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py new file mode 100644 index 0000000..2bd2e5f --- /dev/null +++ b/TTS/tts/models/__init__.py @@ -0,0 +1,14 @@ +from typing import Dict, List, Union + +from TTS.utils.generic_utils import 
find_module + + +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": + print(" > Using model: {}".format(config.model)) + # fetch the right model implementation. + if "base_model" in config and config["base_model"] is not None: + MyModel = find_module("TTS.tts.models", config.base_model.lower()) + else: + MyModel = find_module("TTS.tts.models", config.model.lower()) + model = MyModel.init_from_config(config=config, samples=samples) + return model diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py new file mode 100644 index 0000000..b2e51de --- /dev/null +++ b/TTS/tts/models/align_tts.py @@ -0,0 +1,448 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.align_tts.mdn import MDNBlock +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.io import load_fsspec + + +@dataclass +class AlignTTSArgs(Coqpit): + """ + Args: + num_chars (int): + number of unique input to characters + out_channels (int): + number of output tensor channels. It is equal to the expected spectrogram size. + hidden_channels (int): + number of channels in all the model layers. + hidden_channels_ffn (int): + number of channels in transformer's conv layers. + hidden_channels_dp (int): + number of channels in duration predictor network. + num_heads (int): + number of attention heads in transformer networks. + num_transformer_layers (int): + number of layers in encoder and decoder transformer blocks. + dropout_p (int): + dropout rate in transformer layers. + length_scale (int, optional): + coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1. + num_speakers (int, optional): + number of speakers for multi-speaker training. Defaults to 0. + external_c (bool, optional): + enable external speaker embeddings. Defaults to False. + c_in_channels (int, optional): + number of channels in speaker embedding vectors. Defaults to 0. + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 256 + hidden_channels_dp: int = 256 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + length_scale: float = 1.0 + num_speakers: int = 0 + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_dim: int = 0 + + +class AlignTTS(BaseTTS): + """AlignTTS with modified duration predictor. + https://arxiv.org/pdf/2003.01950.pdf + + Encoder -> DurationPredictor -> Decoder + + Check :class:`AlignTTSArgs` for the class arguments. + + Paper Abstract: + Targeting at both high efficiency and performance, we propose AlignTTS to predict the + mel-spectrum in parallel. 
AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a + sequence of characters, and the duration of each character is determined by a duration predictor.Instead of + adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented + to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset s + how that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean + option score (MOS), but also a high efficiency which is more than 50 times faster than real-time. + + Note: + Original model uses a separate character embedding layer for duration predictor. However, it causes the + duration predictor to overfit and prevents learning higher level interactions among characters. Therefore, + we predict durations based on encoder outputs which has higher level information about input characters. This + enables training without phases as in the original paper. + + Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture + differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. + + Examples: + >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig + >>> config = AlignTTSConfig() + >>> model = AlignTTS(config) + + """ + + # pylint: disable=dangerous-default-value + + def __init__( + self, + config: "AlignTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self.speaker_manager = speaker_manager + self.phase = -1 + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + + self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) + + self.embedded_speaker_dim = 0 + self.init_multispeaker(config) + + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + self.embedded_speaker_dim, + ) + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp) + + self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1) + + self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels) + + if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1) + + @staticmethod + def compute_log_probs(mu, log_sigma, y): + # pylint: disable=protected-access, c-extension-no-member + y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] + mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + expanded_y, expanded_mu = torch.broadcast_tensors(y, mu) + exponential = -0.5 * torch.mean( + torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1 + ) # B, L, T + logp = exponential - 0.5 * log_sigma.mean(dim=-1) + 
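# logp: [B, T_en, T_dec] Gaussian log-likelihood of each decoder frame under each input token;
+        # maximum_path() later searches this matrix for the best monotonic alignment.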
return logp + + def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask): + # find the max alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + log_p = self.compute_log_probs(mu, log_sigma, y) + # [B, T_en, T_dec] + attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1) + dr_mas = torch.sum(attn, -1) + return dr_mas.squeeze(1), log_p + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Examples:: + - encoder output: [a,b,c,d] + - durations: [1, 3, 2, 1] + + - expanded: [a, b, b, b, c, c, d] + - attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + + return x + g + + def _forward_encoder(self, x, x_lengths, g=None): + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.speaker_embedding(g)) # [B, C, 1] + + if g is not None: + g = g.unsqueeze(-1) + + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g + + def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de, attn.transpose(1, 2) + + def _forward_mdn(self, o_en, y, y_lengths, x_mask): + # MAS potentials and alignment + mu, log_sigma = self.mdn_block(o_en) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask) + return dr_mas, mu, log_sigma, logp + + def forward( + self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None + ): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: 
:math:`[B]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + """ + y = y.transpose(1, 2) + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None + if phase == 0: + # train encoder and MDN + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + attn = self.generate_attn(dr_mas, x_mask, y_mask) + elif phase == 1: + # train decoder + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g) + elif phase == 2: + # train the whole except duration predictor + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + elif phase == 3: + # train duration predictor + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(x, x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + else: + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + dr_mas_log = torch.log(dr_mas + 1).squeeze(1) + outputs = { + "model_outputs": o_de.transpose(1, 2), + "alignments": attn, + "durations_log": o_dr_log, + "durations_mas_log": dr_mas_log, + "mu": mu, + "log_sigma": log_sigma, + "logp": logp, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - g: :math:`[B, C]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + # pad input to prevent dropping the last word + # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + # o_dr_log = self.duration_predictor(x, x_mask) + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + # duration predictor pass + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn} + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input, self.phase) + loss_dict = criterion( + outputs["logp"], + outputs["model_outputs"], + mel_input, + 
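+            # `criterion` is `AlignTTSLoss` (see `get_criterion` below); the `phase`
+            # argument passed last selects which terms (MDN, spectrogram/SSIM, duration)
+            # are active at the current training stage.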
mel_lengths, + outputs["durations_log"], + outputs["durations_mas_log"], + text_lengths, + phase=self.phase, + ) + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import AlignTTSLoss # pylint: disable=import-outside-toplevel + + return AlignTTSLoss(self.config) + + @staticmethod + def _set_phase(config, global_step): + """Decide AlignTTS training phase""" + if isinstance(config.phase_start_steps, list): + vals = [i < global_step for i in config.phase_start_steps] + if not True in vals: + phase = 0 + else: + phase = ( + len(config.phase_start_steps) + - [i < global_step for i in config.phase_start_steps][::-1].index(True) + - 1 + ) + else: + phase = None + return phase + + def on_epoch_start(self, trainer): + """Set AlignTTS training phase on epoch start.""" + self.phase = self._set_phase(trainer.config, trainer.total_steps_done) + + @staticmethod + def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (AlignTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
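+
+        Example (illustrative sketch; assumes the default config and no training samples):
+
+            >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig
+            >>> config = AlignTTSConfig()
+            >>> model = AlignTTS.init_from_config(config)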
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return AlignTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py new file mode 100644 index 0000000..e5edffd --- /dev/null +++ b/TTS/tts/models/bark.py @@ -0,0 +1,284 @@ +import os +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from coqpit import Coqpit +from encodec import EncodecModel +from transformers import BertTokenizer + +from TTS.tts.layers.bark.inference_funcs import ( + codec_decode, + generate_coarse, + generate_fine, + generate_text_semantic, + generate_voice, + load_voice, +) +from TTS.tts.layers.bark.load_model import load_model +from TTS.tts.layers.bark.model import GPT +from TTS.tts.layers.bark.model_fine import FineGPT +from TTS.tts.models.base_tts import BaseTTS + + +@dataclass +class BarkAudioConfig(Coqpit): + sample_rate: int = 24000 + output_sample_rate: int = 24000 + + +class Bark(BaseTTS): + def __init__( + self, + config: Coqpit, + tokenizer: BertTokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased"), + ) -> None: + super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None) + self.config.num_chars = len(tokenizer) + self.tokenizer = tokenizer + self.semantic_model = GPT(config.semantic_config) + self.coarse_model = GPT(config.coarse_config) + self.fine_model = FineGPT(config.fine_config) + self.encodec = EncodecModel.encodec_model_24khz() + self.encodec.set_target_bandwidth(6.0) + + @property + def device(self): + return next(self.parameters()).device + + def load_bark_models(self): + self.semantic_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text" + ) + self.coarse_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["coarse"], + device=self.device, + config=self.config, + model_type="coarse", + ) + self.fine_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine" + ) + + def train_step( + self, + ): + pass + + def text_to_semantic( + self, + text: str, + history_prompt: Optional[str] = None, + temp: float = 0.7, + base=None, + allow_early_stop=True, + **kwargs, + ): + """Generate semantic array from text. + + Args: + text: text to be turned into audio + history_prompt: history choice for audio cloning + temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy semantic array to be fed into `semantic_to_waveform` + """ + x_semantic = generate_text_semantic( + text, + self, + history_prompt=history_prompt, + temp=temp, + base=base, + allow_early_stop=allow_early_stop, + **kwargs, + ) + return x_semantic + + def semantic_to_waveform( + self, + semantic_tokens: np.ndarray, + history_prompt: Optional[str] = None, + temp: float = 0.7, + base=None, + ): + """Generate audio array from semantic input. 
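+
+        The semantic tokens are first expanded into coarse EnCodec codebook tokens,
+        the remaining fine codebooks are filled in by the fine model, and the result
+        is decoded to a waveform with EnCodec (see the calls below).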
+ + Args: + semantic_tokens: semantic token output from `text_to_semantic` + history_prompt: history choice for audio cloning + temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy audio array at sample frequency 24khz + """ + x_coarse_gen = generate_coarse( + semantic_tokens, + self, + history_prompt=history_prompt, + temp=temp, + base=base, + ) + x_fine_gen = generate_fine( + x_coarse_gen, + self, + history_prompt=history_prompt, + temp=0.5, + base=base, + ) + audio_arr = codec_decode(x_fine_gen, self) + return audio_arr, x_coarse_gen, x_fine_gen + + def generate_audio( + self, + text: str, + history_prompt: Optional[str] = None, + text_temp: float = 0.7, + waveform_temp: float = 0.7, + base=None, + allow_early_stop=True, + **kwargs, + ): + """Generate audio array from input text. + + Args: + text: text to be turned into audio + history_prompt: history choice for audio cloning + text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy audio array at sample frequency 24khz + """ + x_semantic = self.text_to_semantic( + text, + history_prompt=history_prompt, + temp=text_temp, + base=base, + allow_early_stop=allow_early_stop, + **kwargs, + ) + audio_arr, c, f = self.semantic_to_waveform( + x_semantic, history_prompt=history_prompt, temp=waveform_temp, base=base + ) + return audio_arr, [x_semantic, c, f] + + def generate_voice(self, audio, speaker_id, voice_dir): + """Generate a voice from the given audio and text. + + Args: + audio (str): Path to the audio file. + speaker_id (str): Speaker name. + voice_dir (str): Path to the directory to save the generate voice. + """ + if voice_dir is not None: + voice_dirs = [voice_dir] + try: + _ = load_voice(speaker_id, voice_dirs) + except (KeyError, FileNotFoundError): + output_path = os.path.join(voice_dir, speaker_id + ".npz") + os.makedirs(voice_dir, exist_ok=True) + generate_voice(audio, self, output_path) + + def _set_voice_dirs(self, voice_dirs): + def_voice_dir = None + if isinstance(self.config.DEF_SPEAKER_DIR, str): + os.makedirs(self.config.DEF_SPEAKER_DIR, exist_ok=True) + if os.path.isdir(self.config.DEF_SPEAKER_DIR): + def_voice_dir = self.config.DEF_SPEAKER_DIR + _voice_dirs = [def_voice_dir] if def_voice_dir is not None else [] + if voice_dirs is not None: + if isinstance(voice_dirs, str): + voice_dirs = [voice_dirs] + _voice_dirs = voice_dirs + _voice_dirs + return _voice_dirs + + # TODO: remove config from synthesize + def synthesize( + self, text, config, speaker_id="random", voice_dirs=None, **kwargs + ): # pylint: disable=unused-argument + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (BarkConfig): Config with inference parameters. + speaker_id (str): One of the available speaker names. If `random`, it generates a random speaker. + speaker_wav (str): Path to the speaker audio file for cloning a new voice. It is cloned and saved in + `voice_dirs` with the name `speaker_id`. Defaults to None. + voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None. + **kwargs: Model specific inference settings used by `generate_audio()` and `TTS.tts.layers.bark.inference_funcs.generate_text_semantic(). 
+
+        Returns:
+            A dictionary of the output values with `wav` as the output waveform and
+            `text_inputs` as the input text.
+
+        """
+        speaker_id = "random" if speaker_id is None else speaker_id
+        voice_dirs = self._set_voice_dirs(voice_dirs)
+        history_prompt = load_voice(self, speaker_id, voice_dirs)
+        outputs = self.generate_audio(text, history_prompt=history_prompt, **kwargs)
+        return_dict = {
+            "wav": outputs[0],
+            "text_inputs": text,
+        }
+
+        return return_dict
+
+    def eval_step(self):
+        ...
+
+    def forward(self):
+        ...
+
+    def inference(self):
+        ...
+
+    @staticmethod
+    def init_from_config(config: "BarkConfig", **kwargs):  # pylint: disable=unused-argument
+        return Bark(config)
+
+    # pylint: disable=unused-argument, redefined-builtin
+    def load_checkpoint(
+        self,
+        config,
+        checkpoint_dir,
+        text_model_path=None,
+        coarse_model_path=None,
+        fine_model_path=None,
+        hubert_model_path=None,
+        hubert_tokenizer_path=None,
+        eval=False,
+        strict=True,
+        **kwargs,
+    ):
+        """Load the model checkpoints from a directory. Bark is split over multiple checkpoint files, and all of
+        them are expected to be found under the given `checkpoint_dir` with the right names.
+        If `eval` is True, set the model to eval mode.
+
+        Args:
+            config (BarkConfig): The model config.
+            checkpoint_dir (str): The directory where the checkpoints are stored.
+            text_model_path (str, optional): Path to the text (semantic) model checkpoint. Defaults to `<checkpoint_dir>/text_2.pt`.
+            coarse_model_path (str, optional): Path to the coarse model checkpoint. Defaults to `<checkpoint_dir>/coarse_2.pt`.
+            fine_model_path (str, optional): Path to the fine model checkpoint. Defaults to `<checkpoint_dir>/fine_2.pt`.
+            hubert_model_path (str, optional): Path to the HuBERT model checkpoint. Defaults to `<checkpoint_dir>/hubert.pt`.
+            hubert_tokenizer_path (str, optional): Path to the HuBERT tokenizer checkpoint. Defaults to `<checkpoint_dir>/tokenizer.pth`.
+            eval (bool, optional): Whether to set the model to eval mode. Defaults to False.
+            strict (bool, optional): Whether to load the model strictly. Defaults to True.
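+
+        Example (illustrative; the checkpoint directory below is a placeholder):
+
+            >>> from TTS.tts.configs.bark_config import BarkConfig
+            >>> config = BarkConfig()
+            >>> bark = Bark.init_from_config(config)
+            >>> bark.load_checkpoint(config, checkpoint_dir="/path/to/bark_checkpoints/", eval=True)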
+ """ + text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt") + coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt") + fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt") + hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt") + hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth") + + self.config.LOCAL_MODEL_PATHS["text"] = text_model_path + self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path + self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path + self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path + self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path + + self.load_bark_models() + + if eval: + self.eval() diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py new file mode 100644 index 0000000..f38dace --- /dev/null +++ b/TTS/tts/models/base_tacotron.py @@ -0,0 +1,305 @@ +import copy +from abc import abstractmethod +from typing import Dict, Tuple + +import torch +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.losses import TacotronLoss +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec +from TTS.utils.training import gradual_training_scheduler + + +class BaseTacotron(BaseTTS): + """Base class shared by Tacotron and Tacotron2""" + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields as class attributes + for key in config: + setattr(self, key, config[key]) + + # layers + self.embedding = None + self.encoder = None + self.decoder = None + self.postnet = None + + # init tensors + self.embedded_speakers = None + self.embedded_speakers_projected = None + + # global style token + if self.gst and self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim + self.gst_layer = None + + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim + self.capacitron_vae_layer = None + + # additional layers + self.decoder_backward = None + self.coarse_decoder = None + + @staticmethod + def _format_aux_input(aux_input: Dict) -> Dict: + """Set missing fields to their default values""" + if aux_input: + return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + return None + + ############################# + # INIT FUNCTIONS + ############################# + + def _init_backward_decoder(self): + """Init the backward decoder for Forward-Backward decoding.""" + self.decoder_backward = copy.deepcopy(self.decoder) + + def _init_coarse_decoder(self): + """Init the coarse decoder for Double-Decoder Consistency.""" + self.coarse_decoder = copy.deepcopy(self.decoder) + self.coarse_decoder.r_init = self.ddc_r + self.coarse_decoder.set_r(self.ddc_r) + + ############################# + # CORE FUNCTIONS + ############################# + + @abstractmethod + def forward(self): + 
pass + + @abstractmethod + def inference(self): + pass + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load model checkpoint and set up internals. + + Args: + config (Coqpi): model configuration. + checkpoint_path (str): path to checkpoint file. + eval (bool, optional): whether to load model for evaluation. + cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. + """ + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + # TODO: set r in run-time by taking it from the new config + if "r" in state: + # set r from the state (for compatibility with older checkpoints) + self.decoder.set_r(state["r"]) + elif "config" in state: + # set r from config used at training time (for inference) + self.decoder.set_r(state["config"]["r"]) + else: + # set r from the new config (for new-models) + self.decoder.set_r(config.r) + if eval: + self.eval() + print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") + assert not self.training + + def get_criterion(self) -> nn.Module: + """Get the model criterion used in training.""" + return TacotronLoss(self.config) + + @staticmethod + def init_from_config(config: Coqpit): + """Initialize model from config.""" + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config) + return BaseTacotron(config, ap, tokenizer, speaker_manager) + + ########################## + # TEST AND LOG FUNCTIONS # + ########################## + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
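+
+        Note:
+            Each entry in `config.test_sentences` is synthesized with Griffin-Lim
+            (no neural vocoder), and the predicted spectrogram, alignment and audio
+            are collected for logging.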
+ """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + ############################# + # COMMON COMPUTE FUNCTIONS + ############################# + + def compute_masks(self, text_lengths, mel_lengths): + """Compute masks against sequence paddings.""" + # B x T_in_max (boolean) + input_mask = sequence_mask(text_lengths) + output_mask = None + if mel_lengths is not None: + max_len = mel_lengths.max() + r = self.decoder.r + max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len + output_mask = sequence_mask(mel_lengths, max_len=max_len) + return input_mask, output_mask + + def _backward_pass(self, mel_specs, encoder_outputs, mask): + """Run backwards decoder""" + decoder_outputs_b, alignments_b, _ = self.decoder_backward( + encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask + ) + decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() + return decoder_outputs_b, alignments_b + + def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): + """Double Decoder Consistency""" + T = mel_specs.shape[1] + if T % self.coarse_decoder.r > 0: + padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) + mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) + decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( + encoder_outputs.detach(), mel_specs, input_mask + ) + # scale_factor = self.decoder.r_init / self.decoder.r + alignments_backward = torch.nn.functional.interpolate( + alignments_backward.transpose(1, 2), + size=alignments.shape[1], + mode="nearest", + ).transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward[:, :T, :] + return decoder_outputs_backward, alignments_backward + + ############################# + # EMBEDDING FUNCTIONS + ############################# + + def compute_gst(self, inputs, style_input, speaker_embedding=None): + """Compute global style token""" + if isinstance(style_input, dict): + # multiply each style token with a weight + query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) + if speaker_embedding is not None: + query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) + + _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + for k_token, v_amplifier in style_input.items(): + key = 
_GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) + gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) + gst_outputs = gst_outputs + gst_outputs_att * v_amplifier + elif style_input is None: + # ignore style token and return zero tensor + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + else: + # compute style tokens + gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable + inputs = self._concat_speaker_embedding(inputs, gst_outputs) + return inputs + + def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None): + """Capacitron Variational Autoencoder""" + ( + VAE_outputs, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) = self.capacitron_vae_layer( + reference_mel_info, + text_info, + speaker_embedding, # pylint: disable=not-callable + ) + + VAE_outputs = VAE_outputs.to(inputs.device) + encoder_output = self._concat_speaker_embedding( + inputs, VAE_outputs + ) # concatenate to the output of the basic tacotron encoder + return ( + encoder_output, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) + + @staticmethod + def _add_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = outputs + embedded_speakers_ + return outputs + + @staticmethod + def _concat_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = torch.cat([outputs, embedded_speakers_], dim=-1) + return outputs + + ############################# + # CALLBACKS + ############################# + + def on_epoch_start(self, trainer): + """Callback for setting values wrt gradual training schedule. + + Args: + trainer (TrainerTTS): TTS trainer object that is used to train this model. + """ + if self.gradual_training: + r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config) + trainer.config.r = r + self.decoder.set_r(r) + if trainer.config.bidirectional_decoder: + trainer.model.decoder_backward.set_r(r) + print(f"\n > Number of output frames: {self.decoder.r}") diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py new file mode 100644 index 0000000..7871cc3 --- /dev/null +++ b/TTS/tts/models/base_tts.py @@ -0,0 +1,459 @@ +import os +import random +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper + +from TTS.model import BaseTrainerModel +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + +# pylint: skip-file + + +class BaseTTS(BaseTrainerModel): + """Base `tts` class. Every new `tts` model must inherit this. + + It defines common `tts` specific functions on top of `Model` implementation. 
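+
+    Among other things it provides batch formatting, weighted / distributed sampling,
+    data loader construction, speaker and language handling and a generic test run;
+    concrete models implement `forward()` and `inference()` on top of it.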
+ """ + + MODEL_TYPE = "tts" + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__() + self.config = config + self.ap = ap + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + self.language_manager = language_manager + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). + + `ModelArgs` has all the fields reuqired to initialize the model architecture. + + `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + config_num_chars = ( + self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars + ) + num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + if "characters" in config: + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + self.config = config + self.args = config.model_args + elif "Args" in config.__class__.__name__: + self.args = config + else: + raise ValueError("config must be either a *Config or *Args") + + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. + + Args: + config (Coqpit): Model configuration. 
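+
+        Note:
+            If `d_vector_dim` is not set in the config, `embedded_speaker_dim` falls back
+            to 512. With `use_speaker_embedding=True` and no d-vector file, an
+            `nn.Embedding(num_speakers, embedded_speaker_dim)` layer is created and
+            initialized from a normal distribution.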
+ """ + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + elif hasattr(config, "num_speakers"): + self.num_speakers = config.num_speakers + + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if self.speaker_manager is not None: + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if self.language_manager is not None and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + + def format_batch(self, batch: Dict) -> Dict: + """Generic batch formatting for `TTSDataset`. + + You must override this if you use a custom dataset. 
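+
+        Besides renaming fields, it derives per-character durations from the attention
+        masks (when provided) and reshapes the stop targets according to the reduction
+        factor `r`.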
+ + Args: + batch (Dict): [description] + + Returns: + Dict: [description] + """ + # setup input batch + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] + energy = batch["energy"] + language_ids = batch["language_ids"] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets wrt reduction factor + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + "waveform": waveform, + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_names"], + } + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + print(" > Using Language weighted sampler with alpha:", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + print(" > Using Speaker weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + print(" > Using Length weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * 
alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if self.speaker_manager is not None: + if hasattr(config, "model_args"): + speaker_id_mapping = ( + self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None + ) + d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None + else: + speaker_id_mapping = None + d_vector_mapping = None + + # setup multi-lingual attributes + if self.language_manager is not None: + language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), + compute_energy=config.get("compute_energy", False), + energy_cache_path=config.get("energy_cache_path", None), + samples=samples, + ap=self.ap, + return_wav=config.return_wav if "return_wav" in config else False, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + use_noise_augment=False if is_eval else config.use_noise_augment, + verbose=verbose, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + language_id_mapping=language_id_mapping, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=config.shuffle if sampler is None else False, # if there is no other sampler + collate_fn=dataset.collate_fn, + drop_last=config.drop_last, # setting this False might cause issues in AMP training. 
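+                # `sampler` may be None, a WeightedRandomSampler (language / speaker / length
+                # balancing) or a DDP-wrapped sampler, depending on `get_sampler()` above.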
+ sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def _get_test_aux_input( + self, + ) -> Dict: + d_vector = None + if self.config.use_d_vector_file: + d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] + d_vector = (random.sample(sorted(d_vector), 1),) + + aux_inputs = { + "speaker_id": None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1), + "d_vector": d_vector, + "style_wav": None, # TODO: handle GST style input + } + return aux_inputs + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + if isinstance(sen, list): + aux_inputs = self.get_aux_input_from_test_sentences(sen) + sen = aux_inputs["text"] + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return test_figures, test_audios + + def on_init_start(self, trainer): + """Save the speaker.pth and language_ids.json at the beginning of the training. 
Also update both paths.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.pth") + self.speaker_manager.save_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `speakers.pth` is saved to {output_path}.") + print(" > `speakers_file` is updated in the config.json.") + + if self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `language_ids.json` is saved to {output_path}.") + print(" > `language_ids_file` is updated in the config.json.") + + +class BaseTTSE2E(BaseTTS): + def _set_model_args(self, config: Coqpit): + self.config = config + if "Config" in config.__class__.__name__: + num_chars = ( + self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + ) + self.config.model_args.num_chars = num_chars + self.config.num_chars = num_chars + self.args = config.model_args + self.args.num_chars = num_chars + elif "Args" in config.__class__.__name__: + self.args = config + self.args.num_chars = self.args.num_chars + else: + raise ValueError("config must be either a *Config or *Args") diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py new file mode 100644 index 0000000..b1cf886 --- /dev/null +++ b/TTS/tts/models/delightful_tts.py @@ -0,0 +1,1770 @@ +import os +from dataclasses import dataclass, field +from itertools import chain +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample +from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel +from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.models.base_tts import BaseTTSE2E +from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram +from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 +from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy +from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy +from TTS.utils.audio.processor import AudioProcessor +from TTS.utils.io import load_fsspec +from 
TTS.vocoder.layers.losses import MultiScaleSTFTLoss +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + + +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) + if cuda: + return aux_id.cuda() + return aux_id + + +def embedding_to_torch(d_vector, cuda=False): + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).float() + d_vector = d_vector.squeeze().unsqueeze(0) + if cuda: + return d_vector.cuda() + return d_vector + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + return mask + + +def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: + out_list = torch.jit.annotate(List[torch.Tensor], []) + for batch in input_ele: + if len(batch.shape) == 1: + one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0) + else: + one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded + + +def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: + return torch.ceil(lens / stride).int() + + +def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: + assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." 
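+    # Normal init scaled by sqrt(2 / shape[1]), i.e. variance 2 / shape[1]
+    # (Kaiming/He-style scaling over the second dimension).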
+ return torch.randn(shape) * np.sqrt(2 / shape[1]) + + +# pylint: disable=redefined-outer-name +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path: str): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load( + file_path, + ) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window # pylint: disable=global-statement + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + return spec + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def wav_to_energy(y, n_fft, hop_length, win_length, center=False): + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return torch.norm(spec, dim=1, keepdim=True) + + +def name_mel_basis(spec, n_fft, fmax): + n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" + return n_fft_len + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis # pylint: disable=global-statement + mel_basis_key = name_mel_basis(spec, n_fft, fmax) + # pylint: disable=too-many-function-args + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[mel_basis_key], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, 
num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T_y]` + + Return Shapes: + - spec : :math:`[B,C,T_spec]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + mel_basis_key = name_mel_basis(y, n_fft, fmax) + wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn( + sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) # pylint: disable=too-many-function-args + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[mel_basis_key], spec) + spec = amp_to_db(spec) + return spec + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create balancer weight for torch WeightedSampler""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class ForwardTTSE2eF0Dataset(F0Dataset): + """Override F0Dataset to avoid slow computing of pitches""" + + def __init__( + self, + ap, + samples: Union[List[List], List[Dict]], + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + super().__init__( + samples=samples, + ap=ap, + verbose=verbose, + cache_path=cache_path, + precompute_num_workers=precompute_num_workers, + normalize_f0=normalize_f0, + ) + + def _compute_and_save_pitch(self, wav_file, pitch_file=None): + wav, _ = load_audio(wav_file) + f0 = compute_f0( + x=wav.numpy()[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + pitch_fmin=self.ap.pitch_fmin, + win_length=self.ap.win_length, + ) + # skip the last F0 value to align with the spectrogram + if wav.shape[1] % self.ap.hop_length != 0: + f0 = f0[:-1] + if pitch_file: + np.save(pitch_file, f0) + return f0 + + def compute_or_load(self, wav_file, audio_name): + """ + compute pitch and return a 
numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(audio_name, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(wav_file=wav_file, pitch_file=pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + +class ForwardTTSE2eDataset(TTSDataset): + def __init__(self, *args, **kwargs): + # don't init the default F0Dataset in TTSDataset + compute_f0 = kwargs.pop("compute_f0", False) + kwargs["compute_f0"] = False + self.attn_prior_cache_path = kwargs.pop("attn_prior_cache_path") + + super().__init__(*args, **kwargs) + + self.compute_f0 = compute_f0 + self.pad_id = self.tokenizer.characters.pad_id + self.ap = kwargs["ap"] + + if self.compute_f0: + self.f0_dataset = ForwardTTSE2eF0Dataset( + ap=self.ap, + samples=self.samples, + cache_path=kwargs["f0_cache_path"], + precompute_num_workers=kwargs["precompute_num_workers"], + ) + + if self.attn_prior_cache_path is not None: + os.makedirs(self.attn_prior_cache_path, exist_ok=True) + + def __getitem__(self, idx): + item = self.samples[idx] + + rel_wav_path = Path(item["audio_file"]).relative_to(item["root_path"]).with_suffix("") + rel_wav_path = str(rel_wav_path).replace("/", "_") + + raw_text = item["text"] + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) + + try: + token_ids = self.get_token_ids(idx, item["text"]) + except: + print(idx, item) + # pylint: disable=raise-missing-from + raise OSError + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + attn_prior = None + if self.attn_prior_cache_path is not None: + attn_prior = self.load_or_compute_attn_prior(token_ids, wav, rel_wav_path) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "pitch": f0, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "attn_prior": attn_prior, + "audio_unique_name": item["audio_unique_name"], + } + + def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): + """Load or compute and save the attention prior.""" + attn_prior_file = os.path.join(self.attn_prior_cache_path, f"{rel_wav_path}.npy") + # pylint: disable=no-else-return + if os.path.exists(attn_prior_file): + return np.load(attn_prior_file) + else: + token_len = len(token_ids) + mel_len = wav.shape[1] // self.ap.hop_length + attn_prior = compute_attn_prior(token_len, mel_len) + np.save(attn_prior_file, attn_prior) + return attn_prior + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - pitch :math:`[B, T]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - attn_prior: :math:`[[T_token, T_mel]]` + """ + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k 
in batch[0]} + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + pitch_padded = None + if self.compute_f0: + pitch_lens = [p.shape[0] for p in batch["pitch"]] + pitch_lens = torch.LongTensor(pitch_lens) + pitch_lens_max = torch.max(pitch_lens) + pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max) + pitch_padded = pitch_padded.zero_() + self.pad_id + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + + for i in range(B): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + if self.compute_f0: + pitch = batch["pitch"][i] + pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) + + return { + "text_input": token_padded, + "text_lengths": token_lens, + "text_rel_lens": token_rel_lens, + "pitch": pitch_padded, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_unique_names": batch["audio_unique_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "attn_priors": batch["attn_prior"] if batch["attn_prior"][0] is not None else None, + } + + +############################## +# CONFIG DEFINITIONS +############################## + + +@dataclass +class VocoderConfig(Coqpit): + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + use_spectral_norm_discriminator: bool = False + upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) + periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + pretrained_model_path: Optional[str] = None + + +@dataclass +class DelightfulTtsAudioConfig(Coqpit): + sample_rate: int = 22050 + hop_length: int = 256 + win_length: int = 1024 + fft_size: int = 1024 + mel_fmin: float = 0.0 + mel_fmax: float = 8000 + num_mels: int = 100 + pitch_fmax: float = 640.0 + pitch_fmin: float = 1.0 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + do_trim_silence: bool = True + trim_db: int = 45 + do_rms_norm: bool = False + db_level: float = None + power: float = 1.5 + griffin_lim_iters: int = 60 + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + min_level_db: int = -100 + max_norm: float = 4.0 + + +@dataclass +class DelightfulTtsArgs(Coqpit): + num_chars: int = 100 + spec_segment_size: int = 32 + n_hidden_conformer_encoder: int = 512 + n_layers_conformer_encoder: int = 6 + n_heads_conformer_encoder: int = 8 + dropout_conformer_encoder: float = 0.1 + 
kernel_size_conv_mod_conformer_encoder: int = 7 + kernel_size_depthwise_conformer_encoder: int = 7 + lrelu_slope: float = 0.3 + n_hidden_conformer_decoder: int = 512 + n_layers_conformer_decoder: int = 6 + n_heads_conformer_decoder: int = 8 + dropout_conformer_decoder: float = 0.1 + kernel_size_conv_mod_conformer_decoder: int = 11 + kernel_size_depthwise_conformer_decoder: int = 11 + bottleneck_size_p_reference_encoder: int = 4 + bottleneck_size_u_reference_encoder: int = 512 + ref_enc_filters_reference_encoder = [32, 32, 64, 64, 128, 128] + ref_enc_size_reference_encoder: int = 3 + ref_enc_strides_reference_encoder = [1, 2, 1, 2, 1] + ref_enc_pad_reference_encoder = [1, 1] + ref_enc_gru_size_reference_encoder: int = 32 + ref_attention_dropout_reference_encoder: float = 0.2 + token_num_reference_encoder: int = 32 + predictor_kernel_size_reference_encoder: int = 5 + n_hidden_variance_adaptor: int = 512 + kernel_size_variance_adaptor: int = 5 + dropout_variance_adaptor: float = 0.5 + n_bins_variance_adaptor: int = 256 + emb_kernel_size_variance_adaptor: int = 3 + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: str = None + speaker_embedding_channels: int = 384 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + freeze_vocoder: bool = False + freeze_text_encoder: bool = False + freeze_duration_predictor: bool = False + freeze_pitch_predictor: bool = False + freeze_energy_predictor: bool = False + freeze_basis_vectors_predictor: bool = False + freeze_decoder: bool = False + length_scale: float = 1.0 + + +############################## +# MODEL DEFINITION +############################## +class DelightfulTTS(BaseTTSE2E): + """ + Paper:: + https://arxiv.org/pdf/2110.12612.pdf + + Paper Abstract:: + This paper describes the Microsoft end-to-end neural text to speech (TTS) system: DelightfulTTS for Blizzard Challenge 2021. + The goal of this challenge is to synthesize natural and high-quality speech from text, and we approach this goal in two perspectives: + The first is to directly model and generate waveform in 48 kHz sampling rate, which brings higher perception quality than previous systems + with 16 kHz or 24 kHz sampling rate; The second is to model the variation information in speech through a systematic design, which improves + the prosody and naturalness. Specifically, for 48 kHz modeling, we predict 16 kHz mel-spectrogram in acoustic model, and + propose a vocoder called HiFiNet to directly generate 48 kHz waveform from predicted 16 kHz mel-spectrogram, which can better trade off training + efficiency, modelling stability and voice quality. We model variation information systematically from both explicit (speaker ID, language ID, pitch and duration) and + implicit (utterance-level and phoneme-level prosody) perspectives: 1) For speaker and language ID, we use lookup embedding in training and + inference; 2) For pitch and duration, we extract the values from paired text-speech data in training and use two predictors to predict the values in inference; 3) + For utterance-level and phoneme-level prosody, we use two reference encoders to extract the values in training, and use two separate predictors to predict the values in inference. + Additionally, we introduce an improved Conformer block to better model the local and global dependency in acoustic model. 
For task SH1, DelightfulTTS achieves 4.17 mean score in MOS test + and 4.35 in SMOS test, which indicates the effectiveness of our proposed system + + + Model training:: + text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg + spec --------^ + + Examples: + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig + >>> config = ForwardTTSE2eConfig() + >>> model = ForwardTTSE2e(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) + self.ap = ap + + self._set_model_args(config) + self.init_multispeaker(config) + self.binary_loss_weight = None + + self.args.out_channels = self.config.audio.num_mels + self.args.num_mels = self.config.audio.num_mels + self.acoustic_model = AcousticModel(args=self.args, tokenizer=tokenizer, speaker_manager=speaker_manager) + + self.waveform_decoder = HifiganGenerator( + self.config.audio.num_mels, + 1, + self.config.vocoder.resblock_type_decoder, + self.config.vocoder.resblock_dilation_sizes_decoder, + self.config.vocoder.resblock_kernel_sizes_decoder, + self.config.vocoder.upsample_kernel_sizes_decoder, + self.config.vocoder.upsample_initial_channel_decoder, + self.config.vocoder.upsample_rates_decoder, + inference_padding=0, + # cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.config.init_discriminator: + self.disc = VitsDiscriminator( + use_spectral_norm=self.config.vocoder.use_spectral_norm_discriminator, + periods=self.config.vocoder.periods_discriminator, + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def energy_scaler(self): + return self.acoustic_model.energy_scaler + + @property + def length_scale(self): + return self.acoustic_model.length_scale + + @length_scale.setter + def length_scale(self, value): + self.acoustic_model.length_scale = value + + @property + def pitch_mean(self): + return self.acoustic_model.pitch_mean + + @pitch_mean.setter + def pitch_mean(self, value): + self.acoustic_model.pitch_mean = value + + @property + def pitch_std(self): + return self.acoustic_model.pitch_std + + @pitch_std.setter + def pitch_std(self, value): + self.acoustic_model.pitch_std = value + + @property + def mel_basis(self): + return build_mel_basis( + sample_rate=self.ap.sample_rate, + fft_size=self.ap.fft_size, + num_mels=self.ap.num_mels, + mel_fmax=self.ap.mel_fmax, + mel_fmin=self.ap.mel_fmin, + ) # pylint: disable=function-redefined + + def init_for_training(self) -> None: + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + self.config.steps_to_start_discriminator <= 0 + ) # pylint: disable=attribute-defined-outside-init + self.update_energy_scaler = True # pylint: disable=attribute-defined-outside-init + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. 
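+ + Sets `self.embedded_speaker_dim` and `self.num_speakers` (taken from the speaker manager when one is available) and then initializes either the learned speaker embedding dimensions (`_init_speaker_embedding`) or the d-vector dimensions (`_init_d_vector`).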
+ """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + self.args.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.args.embedded_speaker_dim = self.args.speaker_embedding_channels + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + self.args.embedded_speaker_dim = self.args.d_vector_dim + + def _freeze_layers(self): + if self.args.freeze_vocoder: + for param in self.vocoder.parameters(): + param.requires_grad = False + + if self.args.freeze_text_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_duration_predictor: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_pitch_predictor: + for param in self.pitch_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_energy_predictor: + for param in self.energy_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_decoder: + for param in self.decoder.parameters(): + param.requires_grad = False + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + spec_lengths: torch.LongTensor, + spec: torch.FloatTensor, + waveform: torch.FloatTensor, + pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, + attn_priors: torch.FloatTensor = None, + d_vectors: torch.FloatTensor = None, + speaker_idx: torch.LongTensor = None, + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + spec_lengths (torch.LongTensor): Spectrogram sequence lengths. Defaults to None. + spec (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + waveform (torch.FloatTensor): Waveform. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): Spectral energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. + attn_priors (torch.FloatTensor): Attention priors for the aligner network. Defaults to None. + d_vectors (torch.FloatTensor): d-vector speaker embeddings for multi-speaker training. Defaults to None. + speaker_idx (torch.LongTensor): Speaker IDs for the learned speaker embedding layer. Defaults to None. 
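+ + Returns: + Dict: Acoustic model outputs (durations, pitch/energy predictions and targets, aligner terms) where `model_outputs` is replaced by the vocoder waveform segments and the extra keys `acoustic_model_outputs`, `waveform_seg` and `slice_ids` are added.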
+ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - spec_lengths: :math:`[B]` + - spec: :math:`[B, T_max2, C_spec]` + - waveform: :math:`[B, 1, T_max2 * hop_length]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T_max2]` + - energy: :math:`[B, 1, T_max2]` + """ + encoder_outputs = self.acoustic_model( + tokens=x, + src_lens=x_lengths, + mel_lens=spec_lengths, + mels=spec, + pitches=pitch, + energies=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_idx, + ) + + # use mel-spec from the decoder + vocoder_input = encoder_outputs["model_outputs"] # [B, T_max2, C_mel] + + vocoder_input_slices, slice_ids = rand_segments( + x=vocoder_input.transpose(1, 2), + x_lengths=spec_lengths, + segment_size=self.args.spec_segment_size, + let_short_samples=True, + pad_short=True, + ) + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input_slices.detach(), g=g) + wav_seg = segment( + waveform, + slice_ids * self.ap.hop_length, + self.args.spec_segment_size * self.ap.hop_length, + pad_short=True, + ) + model_outputs = {**encoder_outputs} + model_outputs["acoustic_model_outputs"] = encoder_outputs["model_outputs"] + model_outputs["model_outputs"] = vocoder_output + model_outputs["waveform_seg"] = wav_seg + model_outputs["slice_ids"] = slice_ids + return model_outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"d_vectors": None, "speaker_ids": None}, pitch_transform=None, energy_transform=None + ): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + pitch_transform=pitch_transform, + energy_transform=energy_transform, + p_control=None, + d_control=None, + ) + vocoder_input = encoder_outputs["model_outputs"].transpose(1, 2) # [B, T_max2, C_mel] -> [B, C_mel, T_max2] + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input, g=g) + model_outputs = {**encoder_outputs} + model_outputs["model_outputs"] = vocoder_output + return model_outputs + + @torch.no_grad() + def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + ) + model_outputs = {**encoder_outputs} + return model_outputs + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + if optimizer_idx == 0: + tokens = batch["text_input"] + token_lenghts = batch["text_lengths"] + mel = batch["mel_input"] + mel_lens = batch["mel_lengths"] + waveform = batch["waveform"] # [B, T, C] -> [B, C, T] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_priors = batch["attn_priors"] + energy = batch["energy"] + + # generator pass + outputs = self.forward( + x=tokens, + x_lengths=token_lenghts, + spec_lengths=mel_lens, + spec=mel, + waveform=waveform, + pitch=pitch, + energy=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_ids, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + if self.train_disc: + # compute scores and features + scores_d_fake, _, scores_d_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + 
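+ # NOTE: `model_outputs` is detached above so the discriminator update does not backpropagate into the acoustic model or the vocoder.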
+ # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_fake=scores_d_fake, + scores_disc_real=scores_d_real, + ) + return outputs, loss_dict + return None, None + + if optimizer_idx == 1: + mel = batch["mel_input"] + # compute melspec segment + with autocast(enabled=False): + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True + ) + + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + ) + + scores_d_fake = None + feats_d_fake = None + feats_d_real = None + + if self.train_disc: + # compute discriminator scores and features + scores_d_fake, feats_d_fake, _, feats_d_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=True): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_output=self.model_outputs_cache["acoustic_model_outputs"].transpose(1, 2), + mel_target=batch["mel_input"], + mel_lens=batch["mel_lengths"], + dur_output=self.model_outputs_cache["dr_log_pred"], + dur_target=self.model_outputs_cache["dr_log_target"].detach(), + pitch_output=self.model_outputs_cache["pitch_pred"], + pitch_target=self.model_outputs_cache["pitch_target"], + energy_output=self.model_outputs_cache["energy_pred"], + energy_target=self.model_outputs_cache["energy_target"], + src_lens=batch["text_lengths"], + waveform=self.model_outputs_cache["waveform_seg"], + waveform_hat=self.model_outputs_cache["model_outputs"], + p_prosody_ref=self.model_outputs_cache["p_prosody_ref"], + p_prosody_pred=self.model_outputs_cache["p_prosody_pred"], + u_prosody_ref=self.model_outputs_cache["u_prosody_ref"], + u_prosody_pred=self.model_outputs_cache["u_prosody_pred"], + aligner_logprob=self.model_outputs_cache["aligner_logprob"], + aligner_hard=self.model_outputs_cache["aligner_mas"], + aligner_soft=self.model_outputs_cache["aligner_soft"], + binary_loss_weight=self.binary_loss_weight, + feats_fake=feats_d_fake, + feats_real=feats_d_real, + scores_fake=scores_d_fake, + spec_slice=mel_slice, + spec_slice_hat=mel_slice_hat, + skip_disc=not self.train_disc, + ) + + loss_dict["avg_text_length"] = batch["text_lengths"].float().mean() + loss_dict["avg_mel_length"] = batch["mel_lengths"].float().mean() + loss_dict["avg_text_batch_occupancy"] = ( + batch["text_lengths"].float() / batch["text_lengths"].float().max() + ).mean() + loss_dict["avg_mel_batch_occupancy"] = ( + batch["mel_lengths"].float() / batch["mel_lengths"].float().max() + ).mean() + + return self.model_outputs_cache, loss_dict + raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def _log(self, batch, outputs, name_prefix="train"): + figures, audios = {}, {} + + # encoder outputs + model_outputs = outputs[1]["acoustic_model_outputs"] + alignments = outputs[1]["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, None, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec.T, None, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + pitch_avg = abs(outputs[1]["pitch_target"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs[1]["pitch_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot energy figures + energy_avg = abs(outputs[1]["energy_target"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs[1]["energy_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_pitch(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_pitch(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + + # plot the attention mask computed from the predicted durations + alignments_hat = outputs[1]["alignments_dp"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + encoder_audio = mel_to_wav_numpy( + mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.mel_basis, **self.config.audio + ) + audios[f"{name_prefix}/encoder_audio"] = encoder_audio + + # vocoder outputs + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + + vocoder_figures = plot_results(y_hat=y_hat, y=y, ap=self.ap, name_prefix=name_prefix) + figures.update(vocoder_figures) + + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios[f"{name_prefix}/vocoder_audio"] = sample_voice + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use, unused-argument + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previous training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. 
+ """ + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav = None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector = None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} + + def plot_outputs(self, text, wav, alignment, outputs): + figures = {} + pitch_avg_pred = outputs["pitch"].cpu() + energy_avg_pred = outputs["energy"].cpu() + spec = wav_to_mel( + y=torch.from_numpy(wav[None, :]), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + )[0].transpose(0, 1) + pitch = compute_f0( + x=wav[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + ) + input_text = self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(text, language="en")) + input_text = input_text.replace("", "_") + durations = outputs["durations"] + pitch_avg = average_over_durations(torch.from_numpy(pitch)[None, None, :], durations.cpu()) # [1, 1, n_frames] + pitch_avg_pred_denorm = (pitch_avg_pred * self.pitch_std) + self.pitch_mean + figures["alignment"] = plot_alignment(alignment.transpose(1, 2), output_fig=False) + figures["spectrogram"] = plot_spectrogram(spec) + figures["pitch_from_wav"] = plot_pitch(pitch, spec) + figures["pitch_avg_from_wav"] = plot_avg_pitch(pitch_avg.squeeze(), input_text) + figures["pitch_avg_pred"] = plot_avg_pitch(pitch_avg_pred_denorm.squeeze(), input_text) + figures["energy_avg_pred"] = plot_avg_pitch(energy_avg_pred.squeeze(), input_text) + return figures + + def synthesize( + self, + text: str, + speaker_id: str = None, + d_vector: torch.tensor = None, + pitch_transform=None, + **kwargs, + ): # pylint: disable=unused-argument + # TODO: add cloning support with ref_waveform + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + + # set speaker inputs + _speaker_id = None + if speaker_id is not None and 
self.args.use_speaker_embedding: + if isinstance(speaker_id, str) and self.args.use_speaker_embedding: + # get the speaker id for the speaker embedding layer + _speaker_id = self.speaker_manager.name_to_id[speaker_id] + _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda) + + if speaker_id is not None and self.args.use_d_vector_file: + # get the average d_vector for the speaker + d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False) + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference( + text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": _speaker_id}, + pitch_transform=pitch_transform, + # energy_transform=energy_transform + ) + + # collect outputs + wav = outputs["model_outputs"][0].data.cpu().numpy() + alignments = outputs["alignments"] + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + def synthesize_with_gl(self, text: str, speaker_id, d_vector): + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference_spec_decoder( + x=text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}, + ) + + # collect outputs + S = outputs["model_outputs"].cpu().numpy()[0].T + S = db_to_amp_numpy(x=S, gain=1, base=None) + wav = mel_to_wav_numpy(mel=S, mel_basis=self.mel_basis, **self.config.audio) + alignments = outputs["alignments"] + return_dict = { + "wav": wav[None, :], + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
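+ + Each entry of `config.test_sentences` may be a plain string or a list of the form `[text]`, `[text, speaker_name]` or `[text, speaker_name, style_wav]` (see `get_aux_input_from_test_sentences`); e.g. `["Be a voice, not an echo.", "ljspeech"]`, where the speaker name is only illustrative.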
+ """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + outputs = self.synthesize( + aux_inputs["text"], + config=self.config, + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + outputs_gl = self.synthesize_with_gl( + aux_inputs["text"], + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + # speaker_name = self.speaker_manager.speaker_names[aux_inputs["speaker_id"]] + test_audios["{}-audio".format(idx)] = outputs["wav"].T + test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_names and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + + ac = self.ap + + # compute spectrograms + batch["mel_input"] = wav_to_mel( + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + center=False, + ) + + # TODO: Align pitch properly + # assert ( + # batch["pitch"].shape[2] == batch["mel_input"].shape[2] + # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}" + batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] if batch["pitch"] is not None else None + batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() + + # zero the padding frames + batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) + + # format attn priors as we now the max mel length + # TODO: fix 1 diff b/w mel_lengths and attn_priors + + if self.config.use_attn_priors: + attn_priors_np = batch["attn_priors"] + + batch["attn_priors"] = torch.zeros( + batch["mel_input"].shape[0], + batch["mel_lengths"].max(), + batch["text_lengths"].max(), + device=batch["mel_input"].device, + ) + + for i in range(batch["mel_input"].shape[0]): + batch["attn_priors"][i, : attn_priors_np[i].shape[0], : attn_priors_np[i].shape[1]] = torch.from_numpy( + attn_priors_np[i] + ) + + batch["energy"] = None + 
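+ # compute per-frame energy from the waveform and normalize it with the running `energy_scaler`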
batch["energy"] = wav_to_energy( # [B, 1, T_max2] + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + center=False, + ) + batch["energy"] = self.energy_scaler(batch["energy"]) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + print(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = ForwardTTSE2eDataset( + samples=samples, + ap=self.ap, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + compute_f0=config.compute_f0, + f0_cache_path=config.f0_cache_path, + attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait for all the DDP processes to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences in ascending order by length + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=True, + ) + + # get pitch mean and std + self.pitch_mean = dataset.f0_dataset.mean + self.pitch_std = dataset.f0_dataset.std + return loader + + def get_criterion(self): + return [VitsDiscriminatorLoss(self.config), DelightfulTTSLoss(self.config)] + + def get_optimizer(self) -> List: + """Initialize and return the GAN optimizers based on the config parameters. + It returns 2 optimizers in a list. The first one is for the discriminator and the second one is for the generator, matching the `optimizer_idx` order used in `train_step`. + Returns: + List: optimizers. 
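+ + The generator optimizer covers every model parameter whose name does not start with `disc.`, so the discriminator parameters are excluded from it.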
+ """ + optimizer_disc = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc + ) + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer_gen = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer_disc, optimizer_gen] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def on_epoch_end(self, trainer): # pylint: disable=unused-argument + # stop updating mean and var + # TODO: do the same for F0 + self.energy_scaler.eval() + + @staticmethod + def init_from_config( + config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False + ): # pylint: disable=unused-argument + """Initialize the model from config. + + Args: + config (DelightfulTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config.model_args, samples) + ap = AudioProcessor.init_from_config(config=config) + return DelightfulTTS(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager, ap=ap) + + def load_checkpoint(self, config, checkpoint_path, eval=False): + """Load the model from a checkpoint created by the 👟 Trainer.""" + # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_state_dict(self): + """Custom state dict of the model with all the necessary components for inference.""" + save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict()} + + if hasattr(self, "emb_g"): + save_state["speaker_ids"] = self.speaker_manager.speaker_names + + if self.args.use_d_vector_file: + # TODO: implement saving of d_vectors + ... + return save_state + + def save(self, config, checkpoint_path): + """Save model to a file.""" + save_state = self.get_state_dict() + save_state["pitch_mean"] = self.pitch_mean + save_state["pitch_std"] = self.pitch_std + torch.save(save_state, checkpoint_path) + + def on_train_step_start(self, trainer) -> None: + """Enable the discriminator training based on `steps_to_start_discriminator`. + + Args: + trainer (Trainer): Trainer object. 
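+ + It also ramps `self.binary_loss_weight` linearly from 0 to 1 over `config.binary_loss_warmup_epochs`.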
+ """ + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + trainer.total_steps_done >= self.config.steps_to_start_discriminator + ) + + +class DelightfulTTSLoss(nn.Module): + def __init__(self, config): + super().__init__() + + self.mse_loss = nn.MSELoss() + self.mae_loss = nn.L1Loss() + self.forward_sum_loss = ForwardSumLoss() + self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) + + self.mel_loss_alpha = config.mel_loss_alpha + self.aligner_loss_alpha = config.aligner_loss_alpha + self.pitch_loss_alpha = config.pitch_loss_alpha + self.energy_loss_alpha = config.energy_loss_alpha + self.u_prosody_loss_alpha = config.u_prosody_loss_alpha + self.p_prosody_loss_alpha = config.p_prosody_loss_alpha + self.dur_loss_alpha = config.dur_loss_alpha + self.char_dur_loss_alpha = config.char_dur_loss_alpha + self.binary_alignment_loss_alpha = config.binary_align_loss_alpha + + self.vocoder_mel_loss_alpha = config.vocoder_mel_loss_alpha + self.feat_loss_alpha = config.feat_loss_alpha + self.gen_loss_alpha = config.gen_loss_alpha + self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + def forward( + self, + mel_output, + mel_target, + mel_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + energy_output, + energy_target, + src_lens, + waveform, + waveform_hat, + p_prosody_ref, + p_prosody_pred, + u_prosody_ref, + u_prosody_pred, + aligner_logprob, + aligner_hard, + aligner_soft, + binary_loss_weight=None, + feats_fake=None, + feats_real=None, + scores_fake=None, + spec_slice=None, + spec_slice_hat=None, + skip_disc=False, + ): + """ + Shapes: + - mel_output: :math:`(B, C_mel, T_mel)` + - mel_target: :math:`(B, C_mel, T_mel)` + - mel_lens: :math:`(B)` + - dur_output: :math:`(B, T_src)` + - dur_target: :math:`(B, T_src)` + - pitch_output: :math:`(B, 1, T_src)` + - pitch_target: :math:`(B, 1, T_src)` + - energy_output: :math:`(B, 1, T_src)` + - energy_target: :math:`(B, 1, T_src)` + - src_lens: :math:`(B)` + - waveform: :math:`(B, 1, T_wav)` + - waveform_hat: :math:`(B, 1, T_wav)` + - p_prosody_ref: :math:`(B, T_src, 4)` + - p_prosody_pred: :math:`(B, T_src, 4)` + - u_prosody_ref: :math:`(B, 1, 256) + - u_prosody_pred: :math:`(B, 1, 256) + - aligner_logprob: :math:`(B, 1, T_mel, T_src)` + - aligner_hard: :math:`(B, T_mel, T_src)` + - aligner_soft: :math:`(B, T_mel, T_src)` + - spec_slice: :math:`(B, C_mel, T_mel)` + - spec_slice_hat: :math:`(B, C_mel, T_mel)` + """ + loss_dict = {} + src_mask = sequence_mask(src_lens).to(mel_output.device) # (B, T_src) + mel_mask = 
sequence_mask(mel_lens).to(mel_output.device) # (B, T_mel) + + dur_target.requires_grad = False + mel_target.requires_grad = False + pitch_target.requires_grad = False + + masked_mel_predictions = mel_output.masked_select(mel_mask[:, None]) + mel_targets = mel_target.masked_select(mel_mask[:, None]) + mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) + + p_prosody_ref = p_prosody_ref.detach() + p_prosody_loss = 0.5 * self.mae_loss( + p_prosody_ref.masked_select(src_mask.unsqueeze(-1)), + p_prosody_pred.masked_select(src_mask.unsqueeze(-1)), + ) + + u_prosody_ref = u_prosody_ref.detach() + u_prosody_loss = 0.5 * self.mae_loss(u_prosody_ref, u_prosody_pred) + + duration_loss = self.mse_loss(dur_output, dur_target) + + pitch_output = pitch_output.masked_select(src_mask[:, None]) + pitch_target = pitch_target.masked_select(src_mask[:, None]) + pitch_loss = self.mse_loss(pitch_output, pitch_target) + + energy_output = energy_output.masked_select(src_mask[:, None]) + energy_target = energy_target.masked_select(src_mask[:, None]) + energy_loss = self.mse_loss(energy_output, energy_target) + + forward_sum_loss = self.forward_sum_loss(aligner_logprob, src_lens, mel_lens) + + total_loss = ( + (mel_loss * self.mel_loss_alpha) + + (duration_loss * self.dur_loss_alpha) + + (u_prosody_loss * self.u_prosody_loss_alpha) + + (p_prosody_loss * self.p_prosody_loss_alpha) + + (pitch_loss * self.pitch_loss_alpha) + + (energy_loss * self.energy_loss_alpha) + + (forward_sum_loss * self.aligner_loss_alpha) + ) + + if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) + total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + if binary_loss_weight: + loss_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + loss_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + loss_dict["loss_aligner"] = self.aligner_loss_alpha * forward_sum_loss + loss_dict["loss_mel"] = self.mel_loss_alpha * mel_loss + loss_dict["loss_duration"] = self.dur_loss_alpha * duration_loss + loss_dict["loss_u_prosody"] = self.u_prosody_loss_alpha * u_prosody_loss + loss_dict["loss_p_prosody"] = self.p_prosody_loss_alpha * p_prosody_loss + loss_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + loss_dict["loss_energy"] = self.energy_loss_alpha * energy_loss + loss_dict["loss"] = total_loss + + # vocoder losses + if not skip_disc: + loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_dict["vocoder_loss_feat"] = loss_feat + loss_dict["vocoder_loss_gen"] = loss_gen + loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen + + loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.vocoder_mel_loss_alpha + loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) + loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha + loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha + + loss_dict["vocoder_loss_mel"] = loss_mel + loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg + loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc + + loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_stft_sc + loss_stft_mg + return loss_dict diff --git a/TTS/tts/models/forward_tts.py 
b/TTS/tts/models/forward_tts.py new file mode 100644 index 0000000..b6e9ac8 --- /dev/null +++ b/TTS/tts/models/forward_tts.py @@ -0,0 +1,862 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram +from TTS.utils.io import load_fsspec + + +@dataclass +class ForwardTTSArgs(Coqpit): + """ForwardTTS Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + use_aligner (bool): + Whether to use aligner network to learn the text to speech alignment or use pre-computed durations. + If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the + pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True. + + use_pitch (bool): + Use pitch predictor to learn the pitch. Defaults to True. + + use_energy (bool): + Use energy predictor to learn the energy. Defaults to True. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + energy_predictor_hidden_channels (int): + Number of hidden channels in the energy predictor. Defaults to 256. + + energy_predictor_dropout_p (float): + Dropout rate for the energy predictor. Defaults to 0.1. + + energy_predictor_kernel_size (int): + Kernel size of conv layers in the energy predictor. Defaults to 3. + + energy_embedding_kernel_size (int): + Kernel size of the projection layer in the energy predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. 
One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 384 + use_aligner: bool = True + # pitch params + use_pitch: bool = True + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + + # energy params + use_energy: bool = False + energy_predictor_hidden_channels: int = 256 + energy_predictor_kernel_size: int = 3 + energy_predictor_dropout_p: float = 0.1 + energy_embedding_kernel_size: int = 3 + + # duration params + duration_predictor_hidden_channels: int = 256 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + + positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + detach_duration_predictor: bool = False + max_duration: int = 75 + num_speakers: int = 1 + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_dim: int = None + d_vector_file: str = None + + +class ForwardTTS(BaseTTS): + """General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment + network and a pitch predictor. + + If the alignment network is used, the model learns the text-to-speech alignment + from the data instead of using pre-computed durations. + + If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each + input character as in the FastPitch model. + + `ForwardTTS` can be configured to one of these architectures, + + - FastPitch + - SpeedySpeech + - FastSpeech + - FastSpeech2 (requires average speech energy predictor) + + Args: + config (Coqpit): Model coqpit class. 
+ speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models. + Defaults to None. + + Examples: + >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs + >>> config = ForwardTTSArgs() + >>> model = ForwardTTS(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self._set_model_args(config) + + self.init_multispeaker(config) + + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_pitch = self.args.use_pitch + self.use_energy = self.args.use_energy + self.binary_loss_weight = 0.0 + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.embedded_speaker_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) + + self.decoder = Decoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, + ) + + if self.args.use_pitch: + self.pitch_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, + ) + self.pitch_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_energy: + self.energy_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.energy_predictor_hidden_channels, + self.args.energy_predictor_kernel_size, + self.args.energy_predictor_dropout_p, + ) + self.energy_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.energy_embedding_kernel_size, + padding=int((self.args.energy_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_aligner: + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # init speaker manager + if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding): + raise ValueError( + " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." 
+ ) + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # init d-vector embedding + if config.use_d_vector_file: + self.embedded_speaker_dim = config.d_vector_dim + if self.args.d_vector_dim != self.args.hidden_channels: + #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Shapes: + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + + Examples:: + + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + def _forward_encoder( + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Sum encoder outputs and speaker embeddings + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. 
+ + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ + if hasattr(self, "emb_g"): + g = g.type(torch.LongTensor) + g = self.emb_g(g) # [B, C, 1] + if g is not None: + g = g.unsqueeze(-1) + # [B, T, C] + x_emb = self.emb(x) + # encoder pass + #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g) + # speaker conditioning + # TODO: try different ways of conditioning + if g is not None: + if hasattr(self, "proj_g"): + g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) + o_en = o_en + g + return o_en, x_mask, g, x_emb + + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. 
+
+        Shapes:
+            - o_en: :math:`(B, C, T_{en})`
+            - x_mask: :math:`(B, 1, T_{en})`
+            - pitch: :math:`(B, 1, T_{de})`
+            - dr: :math:`(B, T_{en})`
+        """
+        o_pitch = self.pitch_predictor(o_en, x_mask)
+        if pitch is not None:
+            avg_pitch = average_over_durations(pitch, dr)
+            o_pitch_emb = self.pitch_emb(avg_pitch)
+            return o_pitch_emb, o_pitch, avg_pitch
+        o_pitch_emb = self.pitch_emb(o_pitch)
+        return o_pitch_emb, o_pitch
+
+    def _forward_energy_predictor(
+        self,
+        o_en: torch.FloatTensor,
+        x_mask: torch.IntTensor,
+        energy: torch.FloatTensor = None,
+        dr: torch.IntTensor = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """Energy predictor forward pass.
+
+        1. Predict energy from encoder outputs.
+        2. In training - Compute average energy values for each input character from the ground truth energy values.
+        3. Embed average energy values.
+
+        Args:
+            o_en (torch.FloatTensor): Encoder output.
+            x_mask (torch.IntTensor): Input sequence mask.
+            energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None.
+            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
+
+        Returns:
+            Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction.
+
+        Shapes:
+            - o_en: :math:`(B, C, T_{en})`
+            - x_mask: :math:`(B, 1, T_{en})`
+            - energy: :math:`(B, 1, T_{de})`
+            - dr: :math:`(B, T_{en})`
+        """
+        o_energy = self.energy_predictor(o_en, x_mask)
+        if energy is not None:
+            avg_energy = average_over_durations(energy, dr)
+            o_energy_emb = self.energy_emb(avg_energy)
+            return o_energy_emb, o_energy, avg_energy
+        o_energy_emb = self.energy_emb(o_energy)
+        return o_energy_emb, o_energy
+
+    def _forward_aligner(
+        self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
+    ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Aligner forward pass.
+
+        1. Compute a mask to apply to the attention map.
+        2. Run the alignment network.
+        3. Apply MAS to compute the hard alignment map.
+        4. Compute the durations from the hard alignment map.
+
+        Args:
+            x (torch.FloatTensor): Input sequence.
+            y (torch.FloatTensor): Output sequence.
+            x_mask (torch.IntTensor): Input sequence mask.
+            y_mask (torch.IntTensor): Output sequence mask.
+
+        Returns:
+            Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+                Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
+                hard alignment map.
+
+        Shapes:
+            - x: :math:`[B, T_en, C_en]`
+            - y: :math:`[B, T_de, C_de]`
+            - x_mask: :math:`[B, 1, T_en]`
+            - y_mask: :math:`[B, 1, T_de]`
+
+            - o_alignment_dur: :math:`[B, T_en]`
+            - alignment_soft: :math:`[B, T_en, T_de]`
+            - alignment_logprob: :math:`[B, 1, T_de, T_en]`
+            - alignment_mas: :math:`[B, T_en, T_de]`
+        """
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
+        alignment_mas = maximum_path(
+            alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
+        )
+        o_alignment_dur = torch.sum(alignment_mas, -1).int()
+        alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
+        return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
+
+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] 
Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + g = self._set_speaker_input(aux_input) + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() + # encoder pass + o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner + o_alignment_dur = None + alignment_soft = None + alignment_logprob = None + alignment_mas = None + if self.use_aligner: + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) + alignment_soft = alignment_soft.transpose(1, 2) + alignment_mas = alignment_mas.transpose(1, 2) + dr = o_alignment_dur + # pitch predictor pass + o_pitch = None + avg_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + avg_energy = None + if self.args.use_energy: + o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr) + o_en = o_en + o_energy_emb + # decoder pass + o_de, attn = self._forward_decoder( + o_en, dr, x_mask, y_lengths, g=None + ) # TODO: maybe pass speaker embedding (g) too + outputs = { + "model_outputs": o_de, # [B, T, C] + "durations_log": o_dr_log.squeeze(1), # [B, T] + "durations": o_dr.squeeze(1), # [B, T] + "attn_durations": o_attn, # for visualization [B, T_en, T_de'] + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, + "energy_avg": o_energy, 
+ "energy_avg_gt": avg_energy, + "alignments": attn, # [B, T_de, T_en] + "alignment_soft": alignment_soft, + "alignment_mas": alignment_mas, + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = self._set_speaker_input(aux_input) + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + + # pitch predictor pass + o_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + if self.args.use_energy: + o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask) + o_en = o_en + o_energy_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "energy": o_energy, + "durations_log": o_dr_log, + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] if self.args.use_pitch else None + energy = batch["energy"] if self.args.use_energy else None + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + + # forward pass + outputs = self.forward( + text_input, + text_lengths, + mel_lengths, + y=mel_input, + dr=durations, + pitch=pitch, + energy=energy, + aux_input=aux_input, + ) + # use aligner's output as the duration target + if self.use_aligner: + durations = outputs["o_alignment_dur"] + # use float32 in AMP + with autocast(enabled=False): + # compute loss + loss_dict = criterion( + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch_avg"] if self.use_pitch else None, + pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, + energy_output=outputs["energy_avg"] if self.use_energy else None, + energy_target=outputs["energy_avg_gt"] if self.use_energy else None, + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, + alignment_soft=outputs["alignment_soft"], + alignment_hard=outputs["alignment_mas"], + binary_loss_weight=self.binary_loss_weight, + ) + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + """Create common 
logger outputs.""" + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot energy figures + if self.args.use_energy: + energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel + + return ForwardTTSLoss(self.config) + + def on_train_step_start(self, trainer): + """Schedule binary loss weight.""" + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (ForwardTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return ForwardTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py new file mode 100644 index 0000000..bfd1a2b --- /dev/null +++ b/TTS/tts/models/glow_tts.py @@ -0,0 +1,557 @@ +import math +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F + +from TTS.tts.configs.glow_tts_config import GlowTTSConfig +from TTS.tts.layers.glow_tts.decoder import Decoder +from TTS.tts.layers.glow_tts.encoder import Encoder +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.io import load_fsspec + + +class GlowTTS(BaseTTS): + """GlowTTS model. + + Paper:: + https://arxiv.org/abs/2005.11129 + + Paper abstract:: + Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate + mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained + without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS, + a flow-based generative model for parallel TTS that does not require any external aligner. By combining the + properties of flows and dynamic programming, the proposed model searches for the most probable monotonic + alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard + monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows + enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over + the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our + model can be easily extended to a multi-speaker setting. + + Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments. + + Examples: + Init only model layers. + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig(num_chars=2) + >>> model = GlowTTS(config) + + Fully init a model ready for action. All the class attributes and class members + (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. 
+ + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig() + >>> model = GlowTTS.init_from_config(config, verbose=False) + """ + + def __init__( + self, + config: GlowTTSConfig, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + # init multi-speaker layers if necessary + self.init_multispeaker(config) + + self.run_data_dep_init = config.data_dep_init_steps > 0 + self.encoder = Encoder( + self.num_chars, + out_channels=self.out_channels, + hidden_channels=self.hidden_channels_enc, + hidden_channels_dp=self.hidden_channels_dp, + encoder_type=self.encoder_type, + encoder_params=self.encoder_params, + mean_only=self.mean_only, + use_prenet=self.use_encoder_prenet, + dropout_p_dp=self.dropout_p_dp, + c_in_channels=self.c_in_channels, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + def init_multispeaker(self, config: Coqpit): + """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding + vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets + speaker embedding vector dimension to the d-vector dimension from the config. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # set ultimate speaker embedding size + if config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + if self.speaker_manager is not None: + assert ( + config.d_vector_dim == self.speaker_manager.embedding_dim + ), " [!] d-vector dimension mismatch b/w config and speaker manager." 
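+        # NOTE: with external d-vectors the conditioning width follows `d_vector_dim`
+        # (falling back to 512 above); with a learned speaker embedding (initialized
+        # below) it is tied to `hidden_channels_enc`. In both cases `c_in_channels`
+        # is set from `embedded_speaker_dim` at the end of this method.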
+ # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.embedded_speaker_dim = self.hidden_channels_enc + self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + # set conditioning dimensions + self.c_in_channels = self.embedded_speaker_dim + + @staticmethod + def compute_outputs(attn, o_mean, o_log_scale, x_mask): + """Compute and format the mode outputs with the given alignment map""" + y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + # compute total duration with adjustment + o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask + return y_mean, y_log_scale, o_attn_dur + + def unlock_act_norm_layers(self): + """Unlock activation normalization layers for data depended initalization.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(True) + + def lock_act_norm_layers(self): + """Lock activation normalization layers.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(False) + + def _set_speaker_input(self, aux_input: Dict): + if aux_input is None: + d_vectors = None + speaker_ids = None + else: + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]: + g = self._set_speaker_input(aux_input) + # speaker embedding + if g is not None: + if hasattr(self, "emb_g"): + # use speaker embedding layer + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] + else: + # use d-vector + g = F.normalize(g).unsqueeze(-1) # [b, h, 1] + return g + + def forward( + self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Args: + x (torch.Tensor): + Input text sequence ids. :math:`[B, T_en]` + + x_lengths (torch.Tensor): + Lengths of input text sequences. :math:`[B]` + + y (torch.Tensor): + Target mel-spectrogram frames. :math:`[B, T_de, C_mel]` + + y_lengths (torch.Tensor): + Lengths of target mel-spectrogram frames. :math:`[B]` + + aux_input (Dict): + Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model. + :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model usind speaker-embedding + layer. 
:math:`B`
+
+        Returns:
+            Dict:
+                - z: :math:`[B, T_de, C]`
+                - logdet: :math:`B`
+                - y_mean: :math:`[B, T_de, C]`
+                - y_log_scale: :math:`[B, T_de, C]`
+                - alignments: :math:`[B, T_en, T_de]`
+                - durations_log: :math:`[B, T_en, 1]`
+                - total_durations_log: :math:`[B, T_en, 1]`
+        """
+        # [B, T, C] -> [B, C, T]
+        y = y.transpose(1, 2)
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        g = self._speaker_embedding(aux_input)
+        # embedding pass
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        # drop residual frames wrt num_squeeze and set y_lengths.
+        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        # create masks
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        # [B, 1, T_en, T_de]
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        # decoder pass
+        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
+        # find the alignment path
+        with torch.no_grad():
+            o_scale = torch.exp(-2 * o_log_scale)
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
+        attn = attn.squeeze(1).permute(0, 2, 1)
+        outputs = {
+            "z": z.transpose(1, 2),
+            "logdet": logdet,
+            "y_mean": y_mean.transpose(1, 2),
+            "y_log_scale": y_log_scale.transpose(1, 2),
+            "alignments": attn,
+            "durations_log": o_dur_log.transpose(1, 2),
+            "total_durations_log": o_attn_dur.transpose(1, 2),
+        }
+        return outputs
+
+    @torch.no_grad()
+    def inference_with_MAS(
+        self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
+    ):  # pylint: disable=dangerous-default-value
+        """
+        It's similar to the teacher forcing in Tacotron.
+        It was proposed in: https://arxiv.org/abs/2104.05557
+
+        Shapes:
+            - x: :math:`[B, T]`
+            - x_lengths: :math:`B`
+            - y: :math:`[B, T, C]`
+            - y_lengths: :math:`B`
+            - g: :math:`[B, C] or B`
+        """
+        y = y.transpose(1, 2)
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        g = self._speaker_embedding(aux_input)
+        # embedding pass
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        # drop residual frames wrt num_squeeze and set y_lengths.
+ y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) + # create masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # find the alignment path between z and encoder output + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + attn = attn.squeeze(1).permute(0, 2, 1) + + # get predited aligned distribution + z = y_mean * y_mask + + # reverse the decoder and predict using the aligned distribution + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = { + "model_outputs": z.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + @torch.no_grad() + def decoder_inference( + self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Shapes: + - y: :math:`[B, T, C]` + - y_lengths: :math:`B` + - g: :math:`[B, C] or B` + """ + y = y.transpose(1, 2) + y_max_length = y.size(2) + g = self._speaker_embedding(aux_input) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # reverse decoder and predict + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = {} + outputs["model_outputs"] = y.transpose(1, 2) + outputs["logdet"] = logdet + return outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + x_lengths = aux_input["x_lengths"] + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # compute output durations + w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale + w_ceil = torch.clamp_min(torch.ceil(w), 1) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = None + # compute masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # compute attention mask + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask + # decoder pass + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + attn = attn.squeeze(1).permute(0, 2, 1) + outputs = { + "model_outputs": y.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": 
y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + """A single training step. Forward pass and loss computation. Run data depended initialization for the + first `config.data_dep_init_steps` steps. + + Args: + batch (dict): [description] + criterion (nn.Module): [description] + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + if self.run_data_dep_init and self.training: + # compute data-dependent initialization of activation norm layers + self.unlock_act_norm_layers() + with torch.no_grad(): + _ = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + outputs = None + loss_dict = None + self.lock_act_norm_layers() + else: + # normal training step + outputs = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + + with autocast(enabled=False): # avoid mixed_precision in criterion + loss_dict = criterion( + outputs["z"].float(), + outputs["y_mean"].float(), + outputs["y_log_scale"].float(), + outputs["logdet"].float(), + mel_lengths, + outputs["durations_log"].float(), + outputs["total_durations_log"].float(), + text_lengths, + ) + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + alignments = outputs["alignments"] + text_input = batch["text_input"][:1] if batch["text_input"] is not None else None + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None + speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None + + # model runs reverse flow to predict spectrograms + pred_outputs = self.inference( + text_input, + aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + model_outputs = pred_outputs["model_outputs"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. 
+ + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + if len(test_sentences) == 0: + print(" | [!] No test sentences provided.") + else: + for idx, sen in enumerate(test_sentences): + outputs = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + + test_audios["{}-audio".format(idx)] = outputs["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + return test_figures, test_audios + + def preprocess(self, y, y_lengths, y_max_length, attn=None): + if y_max_length is not None: + y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze + y = y[:, :, :y_max_length] + if attn is not None: + attn = attn[:, :, :, :y_max_length] + y_lengths = torch.div(y_lengths, self.num_squeeze, rounding_mode="floor") * self.num_squeeze + return y, y_lengths, y_max_length, attn + + def store_inverse(self): + self.decoder.store_inverse() + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + self.store_inverse() + assert not self.training + + @staticmethod + def get_criterion(): + from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel + + return GlowTTSLoss() + + def on_train_step_start(self, trainer): + """Decide on every training step wheter enable/disable data depended initialization.""" + self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps + + @staticmethod + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. 
+        """
+        from TTS.utils.audio import AudioProcessor
+
+        ap = AudioProcessor.init_from_config(config, verbose)
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        return GlowTTS(new_config, ap, tokenizer, speaker_manager)
diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py
new file mode 100644
index 0000000..e241410
--- /dev/null
+++ b/TTS/tts/models/neuralhmm_tts.py
@@ -0,0 +1,385 @@
+import os
+from typing import Dict, List, Union
+
+import torch
+from coqpit import Coqpit
+from torch import nn
+from trainer.logging.tensorboard_logger import TensorboardLogger
+
+from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
+from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
+from TTS.tts.layers.overflow.plotting_utils import (
+    get_spec_from_most_probable_state,
+    plot_transition_probabilities_to_numpy,
+)
+from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.io import load_fsspec
+
+
+class NeuralhmmTTS(BaseTTS):
+    """Neural HMM TTS model.
+
+    Paper::
+        https://arxiv.org/abs/2108.13320
+
+    Paper abstract::
+        Neural sequence-to-sequence TTS has achieved significantly better output quality
+        than statistical speech synthesis using HMMs. However, neural TTS is generally not probabilistic
+        and uses non-monotonic attention. Attention failures increase training time and can make
+        synthesis babble incoherently. This paper describes how the old and new paradigms can be
+        combined to obtain the advantages of both worlds, by replacing attention in neural TTS with
+        an autoregressive left-right no-skip hidden Markov model defined by a neural network.
+        Based on this proposal, we modify Tacotron 2 to obtain an HMM-based neural TTS model with
+        monotonic alignment, trained to maximise the full sequence likelihood without approximation.
+        We also describe how to combine ideas from classical and contemporary TTS for best results.
+        The resulting example system is smaller and simpler than Tacotron 2, and learns to speak with
+        fewer iterations and less data, whilst achieving comparable naturalness prior to the post-net.
+        Our approach also allows easy control over speaking rate. Audio examples and code
+        are available at https://shivammehta25.github.io/Neural-HMM/ .
+
+    Note:
+        - This is a parameter-efficient version of OverFlow (15.3M vs 28.6M). Since it has half the
+        number of parameters of OverFlow, its synthesis output quality is suboptimal (but comparable to Tacotron2
+        without Postnet), but it learns to speak with even less data and is still significantly faster
+        than other attention-based methods.
+
+        - The neural HMM uses flat start initialization, i.e. it computes the means, stds and transition probabilities
+        of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning.
+        If you change the dataset or want to regenerate the parameters, change the `force_generate_statistics` and
+        `mel_statistics_parameter_path` accordingly.
+
+        - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config.
+        This will significantly increase the memory usage.
This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.neuralhmm_tts_config.NeuralhmmTTSConfig` for class arguments. + """ + + def __init__( + self, + config: "NeuralhmmTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + statistics_dict = torch.load(self.mel_statistics_parameter_path) + self.update_mean_std(statistics_dict) + + mels = self.normalize(mels) + return text, text_len, mels, mel_len + + def normalize(self, x): + return x.sub(self.mean).div(self.std) + + def inverse_normalize(self, x): + return x.mul(self.std).add(self.mean) + + def forward(self, text, text_len, mels, mel_len): + """ + Forward pass for training and computing the log likelihood of a given batch. 
+ + Shapes: + Shapes: + text: :math:`[B, T_in]` + text_len: :math:`[B]` + mels: :math:`[B, T_out, C]` + mel_len: :math:`[B]` + """ + text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len) + encoder_outputs, encoder_output_len = self.encoder(text, text_len) + + log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm( + encoder_outputs, encoder_output_len, mels.transpose(1, 2), mel_len + ) + + outputs = { + "log_probs": log_probs, + "alignments": fwd_alignments, + "transition_vectors": transition_vectors, + "means": means, + } + + return outputs + + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + + outputs = self.forward( + text=text_input, + text_len=text_lengths, + mels=mel_input, + mel_len=mel_lengths, + ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) + + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) + return outputs, loss_dict + + def eval_step(self, batch: Dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def _format_aux_input(self, aux_input: Dict, default_input_dict): + """Set missing fields to their default value. + + Args: + aux_inputs (Dict): Dictionary containing the auxiliary inputs. + """ + default_input_dict = default_input_dict.copy() + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(default_input_dict, aux_input) + return default_input_dict + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. 
+ """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + mels, mel_outputs_len = outputs["hmm_outputs"], outputs["hmm_outputs_len"] + + mels = self.inverse_normalize(mels) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + print( + f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + print( + f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + print( + f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." 
+ ) + statistics = torch.load(trainer.config.mel_statistics_parameter_path) + data_mean, data_std, init_transition_prob = ( + statistics["mean"], + statistics["std"], + statistics["init_transition_prob"], + ) + print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + + trainer.config.flat_start_params["transition_p"] = ( + init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob + ) + OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob) + trainer.model.update_mean_std(statistics) + + @torch.inference_mode() + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument + alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"] + means = torch.stack(outputs["means"], dim=1) + + figures = { + "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)), + "log_alignment": plot_alignment( + alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20) + ), + "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)), + "mel_from_most_probable_state": plot_spectrogram( + get_spec_from_most_probable_state(alignments[0], means[0]), fig_size=(12, 3) + ), + "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)), + } + + # sample one item from the batch -1 will give the smalles item + print(" | > Synthesising audio from the model...") + inference_output = self.inference( + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} + ) + figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) + + states = [p[1] for p in inference_output["input_parameters"][0]] + transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]] + + for i in range((len(transition_probability_synthesising) // 200) + 1): + start = i * 200 + end = (i + 1) * 200 + figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy( + states[start:end], transition_probability_synthesising[start:end] + ) + + audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy()) + return figures, {"audios": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=unused-argument + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int + ): # pylint: disable=unused-argument + """Compute and log evaluation metrics.""" + # Plot model parameters histograms + if isinstance(logger, TensorboardLogger): + # I don't know if any other loggers supports this + for tag, value in self.named_parameters(): + tag = tag.replace(".", "/") + logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps) + + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs[1], self.ap.sample_rate) + logger.test_figures(steps, outputs[0]) + + +class 
NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. + + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py new file mode 100644 index 0000000..92b3c76 --- /dev/null +++ b/TTS/tts/models/overflow.py @@ -0,0 +1,401 @@ +import os +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.logging.tensorboard_logger import TensorboardLogger + +from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils +from TTS.tts.layers.overflow.decoder import Decoder +from TTS.tts.layers.overflow.neural_hmm import NeuralHMM +from TTS.tts.layers.overflow.plotting_utils import ( + get_spec_from_most_probable_state, + plot_transition_probabilities_to_numpy, +) +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec + + +class Overflow(BaseTTS): + """OverFlow TTS model. + + Paper:: + https://arxiv.org/abs/2211.06892 + + Paper abstract:: + Neural HMMs are a type of neural transducer recently proposed for + sequence-to-sequence modelling in text-to-speech. They combine the best features + of classic statistical speech synthesis and modern neural TTS, requiring less + data and fewer training updates, and are less prone to gibberish output caused + by neural attention failures. In this paper, we combine neural HMM TTS with + normalising flows for describing the highly non-Gaussian distribution of speech + acoustics. The result is a powerful, fully probabilistic model of durations and + acoustics that can be trained using exact maximum likelihood. Compared to + dominant flow-based acoustic models, our approach integrates autoregression for + improved modelling of long-range dependences such as utterance-level prosody. + Experiments show that a system based on our proposal gives more accurate + pronunciations and better subjective speech quality than comparable methods, + whilst retaining the original advantages of neural HMMs. Audio examples and code + are available at https://shivammehta25.github.io/OverFlow/. + + Note: + - Neural HMMs uses flat start initialization i.e it computes the means and std and transition probabilities + of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning + If you change the dataset or want to regenerate the parameters change the `force_generate_statistics` and + `mel_statistics_parameter_path` accordingly. + + - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config. + This will significantly increase the memory usage. This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.overflow.OverFlowConfig` for class arguments. 
+ """ + + def __init__( + self, + config: "OverFlowConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + statistics_dict = torch.load(self.mel_statistics_parameter_path) + self.update_mean_std(statistics_dict) + + mels = self.normalize(mels) + return text, text_len, mels, mel_len + + def normalize(self, x): + return x.sub(self.mean).div(self.std) + + def inverse_normalize(self, x): + return x.mul(self.std).add(self.mean) + + def forward(self, text, text_len, mels, mel_len): + """ + Forward pass for training and computing the log likelihood of a given batch. 
+ + Shapes: + Shapes: + text: :math:`[B, T_in]` + text_len: :math:`[B]` + mels: :math:`[B, T_out, C]` + mel_len: :math:`[B]` + """ + text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len) + encoder_outputs, encoder_output_len = self.encoder(text, text_len) + z, z_lengths, logdet = self.decoder(mels.transpose(1, 2), mel_len) + log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm( + encoder_outputs, encoder_output_len, z, z_lengths + ) + + outputs = { + "log_probs": log_probs + logdet, + "alignments": fwd_alignments, + "transition_vectors": transition_vectors, + "means": means, + } + + return outputs + + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + + outputs = self.forward( + text=text_input, + text_len=text_lengths, + mels=mel_input, + mel_len=mel_lengths, + ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) + + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) + return outputs, loss_dict + + def eval_step(self, batch: Dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def _format_aux_input(self, aux_input: Dict, default_input_dict): + """Set missing fields to their default value. + + Args: + aux_inputs (Dict): Dictionary containing the auxiliary inputs. + """ + default_input_dict = default_input_dict.copy() + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(default_input_dict, aux_input) + return default_input_dict + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. 
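For readers wiring this up, a hedged usage sketch of the sampling controls documented above (the `model` and `tokenizer` objects are assumed to exist and a trained checkpoint to be loaded; the text and values are placeholders):

# tokens = torch.LongTensor(tokenizer.text_to_ids("A short test sentence."))[None, :]
# outputs = model.inference(
#     tokens,
#     aux_input={"sampling_temp": 0.667, "max_sampling_time": 1000, "duration_threshold": 0.55},
# )
# mel = outputs["model_outputs"]        # [B, T_out, C], already de-normalized
# states = outputs["state_travelled"]   # HMM states visited per sample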
+ """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + + mels, mel_outputs_len, _ = self.decoder( + outputs["hmm_outputs"].transpose(1, 2), outputs["hmm_outputs_len"], reverse=True + ) + mels = self.inverse_normalize(mels.transpose(1, 2)) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Overflow(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + self.decoder.store_inverse() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + print( + f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + print( + f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + print( + f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." 
+ ) + statistics = torch.load(trainer.config.mel_statistics_parameter_path) + data_mean, data_std, init_transition_prob = ( + statistics["mean"], + statistics["std"], + statistics["init_transition_prob"], + ) + print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + + trainer.config.flat_start_params["transition_p"] = ( + init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob + ) + OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob) + trainer.model.update_mean_std(statistics) + + @torch.inference_mode() + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument + alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"] + means = torch.stack(outputs["means"], dim=1) + + figures = { + "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)), + "log_alignment": plot_alignment( + alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20) + ), + "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)), + "mel_from_most_probable_state": plot_spectrogram( + get_spec_from_most_probable_state(alignments[0], means[0], self.decoder), fig_size=(12, 3) + ), + "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)), + } + + # sample one item from the batch -1 will give the smalles item + print(" | > Synthesising audio from the model...") + inference_output = self.inference( + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} + ) + figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) + + states = [p[1] for p in inference_output["input_parameters"][0]] + transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]] + + for i in range((len(transition_probability_synthesising) // 200) + 1): + start = i * 200 + end = (i + 1) * 200 + figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy( + states[start:end], transition_probability_synthesising[start:end] + ) + + audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy()) + return figures, {"audios": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=unused-argument + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int + ): # pylint: disable=unused-argument + """Compute and log evaluation metrics.""" + # Plot model parameters histograms + if isinstance(logger, TensorboardLogger): + # I don't know if any other loggers supports this + for tag, value in self.named_parameters(): + tag = tag.replace(".", "/") + logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps) + + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs[1], self.ap.sample_rate) + logger.test_figures(steps, outputs[0]) 
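Before the `NLLLoss` definition below, it may help to see the arithmetic `train_step` feeds into it: the flow log-determinant is already folded into `log_probs`, and the sum of mel and text lengths acts as a normalizer. A toy sketch with made-up numbers:

import torch

log_probs = torch.tensor([-1200.0, -950.0])   # per-utterance log likelihood (HMM log-prob + flow logdet)
mel_lengths = torch.tensor([400, 310])
text_lengths = torch.tensor([60, 45])
normalizer = mel_lengths.sum() + text_lengths.sum()   # 815
loss = -(log_probs / normalizer).mean()               # what NLLLoss returns as {"loss": ...}
print(float(loss))                                    # ~1.319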
+ + +class NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. + + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py new file mode 100644 index 0000000..474ec46 --- /dev/null +++ b/TTS/tts/models/tacotron.py @@ -0,0 +1,409 @@ +# coding: utf-8 + +from typing import Dict, List, Tuple, Union + +import torch +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST +from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG +from TTS.tts.models.base_tacotron import BaseTacotron +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.capacitron_optimizer import CapacitronOptimizer + + +class Tacotron(BaseTacotron): + """Tacotron as in https://arxiv.org/abs/1703.10135 + It's an autoregressive encoder-attention-decoder-postnet architecture. + Check `TacotronConfig` for the arguments. + + Args: + config (TacotronConfig): Configuration for the Tacotron model. + speaker_manager (SpeakerManager): Speaker manager to handle multi-speaker settings. Only use if the model is + a multi-speaker model. Defaults to None. + """ + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # set speaker embedding channel size for determining `in_channels` for the connected layers. + # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based + # on the number of speakers infered from the dataset. 
+ if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0) + self.embedding.weight.data.normal_(0, 0.3) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = PostCBHG(self.decoder_output_dim) + self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=self.embedded_speaker_dim + if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding + else None, + text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """ + Shapes: + text: [B, T_in] + text_lengths: [B] + mel_specs: [B, T_out, C] + mel_lengths: [B] + aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C] + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + inputs = self.embedding(text) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x T_in x encoder_in_features + encoder_outputs = self.encoder(inputs) + # sequence masking + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # global style token + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + # speaker embedding + if 
self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=[inputs, text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + # decoder_outputs: B x decoder_in_features x T_out + # alignments: B x T_in x encoder_in_features + # stop_tokens: B x T_in + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if output_mask is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x T_out x decoder_in_features + postnet_outputs = self.postnet(decoder_outputs) + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs) + # B x T_out x posnet_dim + postnet_outputs = self.last_linear(postnet_outputs) + # B x T_out x decoder_in_features + decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text_input, aux_input=None): + aux_input = self._format_aux_input(aux_input) + inputs = self.embedding(text_input) + encoder_outputs = self.encoder(inputs) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[aux_input["style_mel"], reference_mel_length] + if aux_input["style_mel"] is not 
None + else None, + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=aux_input["d_vectors"] + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + ) + if self.num_speakers > 1: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"]) + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = self.last_linear(postnet_outputs) + decoder_outputs = decoder_outputs.transpose(1, 2) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: + """Perform a single training step by fetching the right set of samples from the batch. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([torch.nn.Module]): Callable criterion to compute model loss. + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + linear_input = batch["linear_input"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + linear_input.float(), + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, 
self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + postnet_outputs = outputs["model_outputs"] + decoder_outputs = outputs["decoder_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + linear_input = batch["linear_input"] + + pred_linear_spec = postnet_outputs[0].data.cpu().numpy() + pred_mel_spec = decoder_outputs[0].data.cpu().numpy() + gt_linear_spec = linear_input[0].data.cpu().numpy() + gt_mel_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "pred_linear_spec": plot_spectrogram(pred_linear_spec, ap, output_fig=False), + "real_linear_spec": plot_spectrogram(gt_linear_spec, ap, output_fig=False), + "pred_mel_spec": plot_spectrogram(pred_mel_spec, ap, output_fig=False), + "real_mel_spec": plot_spectrogram(gt_mel_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_spectrogram(pred_linear_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (TacotronConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
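The clipping filter above excludes only the Capacitron `beta` parameter by name; a self-contained toy showing the same name-based filter (the module here is hypothetical and exists only for the demonstration):

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # stand-in for the Capacitron layer; only the parameter name matters here
        self.capacitron_vae_layer = nn.ParameterDict({"beta": nn.Parameter(torch.tensor(1.0))})

toy = ToyModel()
to_clip = [p for n, p in toy.named_parameters() if p.requires_grad and n != "capacitron_vae_layer.beta"]
print(len(to_clip))  # 2: the linear weight and bias; beta is left unclipped
# torch.nn.utils.clip_grad_norm_(to_clip, max_norm=5.0)  # would run after backward()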
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Tacotron(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py new file mode 100644 index 0000000..71ab1ea --- /dev/null +++ b/TTS/tts/models/tacotron2.py @@ -0,0 +1,433 @@ +# coding: utf-8 + +from typing import Dict, List, Union + +import torch +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST +from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet +from TTS.tts.models.base_tacotron import BaseTacotron +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.capacitron_optimizer import CapacitronOptimizer + + +class Tacotron2(BaseTacotron): + """Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`. + + Paper:: + https://arxiv.org/abs/1712.05884 + + Paper abstract:: + This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text. + The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character + embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize + timedomain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable + to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation + studies of key components of our system and evaluate the impact of using mel spectrograms as the input to + WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic + intermediate representation enables significant simplification of the WaveNet architecture. + + Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments. + + Args: + config (TacotronConfig): + Configuration for the Tacotron2 model. + speaker_manager (SpeakerManager): + Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None. 
+ """ + + def __init__( + self, + config: "Tacotron2Config", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + self.decoder_output_dim = config.out_channels + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # init multi-speaker layers + if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = Postnet(self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron VAE Layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=self.embedded_speaker_dim + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + @staticmethod + def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): + """Final reshape of the model output tensors.""" + mel_outputs = mel_outputs.transpose(1, 2) + mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) + return mel_outputs, mel_outputs_postnet, alignments + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """Forward pass for training with Teacher Forcing. 
+ + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + mel_specs: :math:`[B, T_out, C]` + mel_lengths: :math:`[B]` + aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]` + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + # compute mask for padding + # B x T_in_max (boolean) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x D_embed x T_in_max + embedded_inputs = self.embedding(text).transpose(1, 2) + # B x T_in_max x D_en + encoder_outputs = self.encoder(embedded_inputs, text_lengths) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + # capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=[embedded_inputs.transpose(1, 2), text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + + # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if mel_lengths is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x mel_dim x T_out + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) + # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text, aux_input=None): + """Forward pass for inference with no Teacher-Forcing. 
+ + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + """ + aux_input = self._format_aux_input(aux_input) + embedded_inputs = self.embedding(text).transpose(1, 2) + encoder_outputs = self.encoder.inference(embedded_inputs) + + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[aux_input["style_mel"], reference_mel_length] + if aux_input["style_mel"] is not None + else None, + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=aux_input["d_vectors"] + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + ) + + if self.num_speakers > 1: + if not self.use_d_vector_file: + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None] + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + embedded_speakers = aux_input["d_vectors"] + + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module): + """A single training step. Forward pass and loss computation. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([type]): Callable criterion to compute model loss. 
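The reduction-factor bookkeeping in the `train_step` body below pads the alignment lengths up to the next multiple of `r` whenever the longest mel in the batch is not already a multiple. A worked example with toy lengths:

import torch

r = 2
mel_lengths = torch.tensor([7, 9])                              # max is 9, and 9 % 2 != 0
alignment_lengths = (mel_lengths + (r - (mel_lengths.max() % r))) // r
print(alignment_lengths)                                        # tensor([4, 5])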
+ """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + None, + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + """Create dashboard log information.""" + postnet_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + + pred_spec = postnet_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + """Log training progress.""" + figures, audios = 
self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (Tacotron2Config): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(new_config, samples) + return Tacotron2(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py new file mode 100644 index 0000000..16644ff --- /dev/null +++ b/TTS/tts/models/tortoise.py @@ -0,0 +1,911 @@ +import os +import random +from contextlib import contextmanager +from dataclasses import dataclass +from time import time + +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit +from tqdm import tqdm + +from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram +from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel +from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice +from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead +from TTS.tts.layers.tortoise.clvp import CLVP +from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps +from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts +from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter +from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.tortoise.vocoder import VocConf, VocType +from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment +from TTS.tts.models.base_tts import BaseTTS + + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + tp = t[..., :length] + if t.shape[-1] == length: + tp = t + elif t.shape[-1] < length: + tp = F.pad(t, (0, length - t.shape[-1])) + return tp + + +def deterministic_state(seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. + # torch.use_deterministic_algorithms(True) + + return seed + + +def load_discrete_vocoder_diffuser( + trained_diffusion_steps=4000, + desired_diffusion_steps=200, + cond_free=True, + cond_free_k=1, + sampler="ddim", +): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. 
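`pad_or_truncate` above is a simple right-clip / zero-pad helper; a quick sanity check of its behaviour on toy tensors:

import torch
import torch.nn.functional as F

x = torch.ones(1, 5)
clipped = x[..., :3]                       # what pad_or_truncate does when the clip is too long
padded = F.pad(x, (0, 8 - x.shape[-1]))    # and when it is too short: zeros appended on the right
print(clipped.shape, padded.shape)         # torch.Size([1, 3]) torch.Size([1, 8])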
+ """ + return SpacedDiffusion( + use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), + model_mean_type="epsilon", + model_var_type="learned_range", + loss_type="mse", + betas=get_named_beta_schedule("linear", trained_diffusion_steps), + conditioning_free=cond_free, + conditioning_free_k=cond_free_k, + sampler=sampler, + ) + + +def format_conditioning(clip, cond_length=132300, device="cuda", **kwargs): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. + """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start : rand_start + cond_length] + mel_clip = TorchMelSpectrogram(**kwargs)(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).to(device) + + +def fix_autoregressive_output(codes, stop_token, complain=True): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). + This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + if complain: + print( + "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " + "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " + "try breaking up your input text." + ) + return codes + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + return codes + + +def do_spectrogram_diffusion( + diffusion_model, + diffuser, + latents, + conditioning_latents, + temperature=1, + verbose=True, +): + """ + Uses the specified diffusion model to convert discrete codes into a spectrogram. + """ + with torch.no_grad(): + output_seq_len = ( + latents.shape[1] * 4 * 24000 // 22050 + ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent( + latents, conditioning_latents, output_seq_len, False + ) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.sample_loop( + diffusion_model, + output_shape, + noise=noise, + model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, + progress=verbose, + ) + return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] + + +def classify_audio_clip(clip, model_dir): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. 
+ """ + classifier = AudioMiniEncoderWithClassifierHead( + 2, + spec_dim=1, + embedding_dim=512, + depth=5, + downsample_factor=4, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + base_channels=32, + dropout=0, + kernel_size=5, + distribute_zero_label=False, + ) + classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"))) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + +def pick_best_batch_size_for_gpu(): + """ + Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give + you a good shot. + """ + if torch.cuda.is_available(): + _, available = torch.cuda.mem_get_info() + availableGb = available / (1024**3) + batch_size = 1 + if availableGb > 14: + batch_size = 16 + elif availableGb > 10: + batch_size = 8 + elif availableGb > 7: + batch_size = 4 + return batch_size + + +@dataclass +class TortoiseAudioConfig(Coqpit): + sample_rate: int = 22050 + diffusion_sample_rate: int = 24000 + output_sample_rate: int = 24000 + + +@dataclass +class TortoiseArgs(Coqpit): + """A dataclass to represent Tortoise model arguments that define the model structure. + + Args: + autoregressive_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. + high_vram (bool, optional): Whether to use high VRAM. Defaults to False. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + ar_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. + diff_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. + num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet. + + For UnifiedVoice model: + ar_max_mel_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + ar_max_conditioning_inputs (int, optional): The maximum conditioning inputs for the autoregressive model. Defaults to 2. + ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + ar_model_dim (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + ar_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. + ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + ar_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + + For DiffTTS model: + diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024. + diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10. + diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100. 
+ diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200. + diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024. + diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193. + diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0. + diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False. + diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16. + diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0. + diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0. + + For ConditionalLatentVariablePerseq model: + clvp_dim_text (int): The dimension of the text input for the CLVP module. Defaults to 768. + clvp_dim_speech (int): The dimension of the speech input for the CLVP module. Defaults to 768. + clvp_dim_latent (int): The dimension of the latent representation for the CLVP module. Defaults to 768. + clvp_num_text_tokens (int): The number of text tokens used by the CLVP module. Defaults to 256. + clvp_text_enc_depth (int): The depth of the text encoder in the CLVP module. Defaults to 20. + clvp_text_seq_len (int): The maximum sequence length of the text input for the CLVP module. Defaults to 350. + clvp_text_heads (int): The number of attention heads used by the text encoder in the CLVP module. Defaults to 12. + clvp_num_speech_tokens (int): The number of speech tokens used by the CLVP module. Defaults to 8192. + clvp_speech_enc_depth (int): The depth of the speech encoder in the CLVP module. Defaults to 20. + clvp_speech_heads (int): The number of attention heads used by the speech encoder in the CLVP module. Defaults to 12. + clvp_speech_seq_len (int): The maximum sequence length of the speech input for the CLVP module. Defaults to 430. + clvp_use_xformers (bool): A flag indicating whether the model uses transformers in the CLVP module. Defaults to True. + duration_const (int): A constant value used in the model. Defaults to 102400. 
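A hedged sketch of how these arguments are usually supplied (it assumes `TortoiseConfig` carries a `model_args: TortoiseArgs` field, as the other model configs in this patch do; the values are illustrative):

# from TTS.tts.configs.tortoise_config import TortoiseConfig
# args = TortoiseArgs(kv_cache=True, high_vram=False, diff_use_fp16=False)
# config = TortoiseConfig(model_args=args)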
+ """ + + autoregressive_batch_size: int = 1 + enable_redaction: bool = False + high_vram: bool = False + kv_cache: bool = True + ar_checkpoint: str = None + clvp_checkpoint: str = None + diff_checkpoint: str = None + num_chars: int = 255 + vocoder: VocType = VocConf.Univnet + + # UnifiedVoice params + ar_max_mel_tokens: int = 604 + ar_max_text_tokens: int = 402 + ar_max_conditioning_inputs: int = 2 + ar_layers: int = 30 + ar_model_dim: int = 1024 + ar_heads: int = 16 + ar_number_text_tokens: int = 255 + ar_start_text_token: int = 255 + ar_checkpointing: bool = False + ar_train_solo_embeddings: bool = False + + # DiffTTS params + diff_model_channels: int = 1024 + diff_num_layers: int = 10 + diff_in_channels: int = 100 + diff_out_channels: int = 200 + diff_in_latent_channels: int = 1024 + diff_in_tokens: int = 8193 + diff_dropout: int = 0 + diff_use_fp16: bool = False + diff_num_heads: int = 16 + diff_layer_drop: int = 0 + diff_unconditioned_percentage: int = 0 + + # clvp params + clvp_dim_text: int = 768 + clvp_dim_speech: int = 768 + clvp_dim_latent: int = 768 + clvp_num_text_tokens: int = 256 + clvp_text_enc_depth: int = 20 + clvp_text_seq_len: int = 350 + clvp_text_heads: int = 12 + clvp_num_speech_tokens: int = 8192 + clvp_speech_enc_depth: int = 20 + clvp_speech_heads: int = 12 + clvp_speech_seq_len: int = 430 + clvp_use_xformers: bool = True + # constants + duration_const: int = 102400 + + +class Tortoise(BaseTTS): + """Tortoise model class. + + Currently only supports inference. + + Examples: + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> from TTS.tts.models.tortoise import Tortoise + >>> config = TortoiseConfig() + >>> model = Tortoise.inif_from_config(config) + >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + """ + + def __init__(self, config: Coqpit): + super().__init__(config, ap=None, tokenizer=None) + self.mel_norm_path = None + self.config = config + self.ar_checkpoint = self.args.ar_checkpoint + self.diff_checkpoint = self.args.diff_checkpoint # TODO: check if this is even needed + self.models_dir = config.model_dir + self.autoregressive_batch_size = ( + pick_best_batch_size_for_gpu() + if self.args.autoregressive_batch_size is None + else self.args.autoregressive_batch_size + ) + self.enable_redaction = self.args.enable_redaction + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.tokenizer = VoiceBpeTokenizer() + + self.autoregressive = UnifiedVoice( + max_mel_tokens=self.args.ar_max_mel_tokens, + max_text_tokens=self.args.ar_max_text_tokens, + max_conditioning_inputs=self.args.ar_max_conditioning_inputs, + layers=self.args.ar_layers, + model_dim=self.args.ar_model_dim, + heads=self.args.ar_heads, + number_text_tokens=self.args.ar_number_text_tokens, + start_text_token=self.args.ar_start_text_token, + checkpointing=self.args.ar_checkpointing, + train_solo_embeddings=self.args.ar_train_solo_embeddings, + ).cpu() + + self.diffusion = DiffusionTts( + model_channels=self.args.diff_model_channels, + num_layers=self.args.diff_num_layers, + in_channels=self.args.diff_in_channels, + out_channels=self.args.diff_out_channels, + in_latent_channels=self.args.diff_in_latent_channels, + in_tokens=self.args.diff_in_tokens, + dropout=self.args.diff_dropout, + use_fp16=self.args.diff_use_fp16, + num_heads=self.args.diff_num_heads, + layer_drop=self.args.diff_layer_drop, + 
unconditioned_percentage=self.args.diff_unconditioned_percentage, + ).cpu() + + self.clvp = CLVP( + dim_text=self.args.clvp_dim_text, + dim_speech=self.args.clvp_dim_speech, + dim_latent=self.args.clvp_dim_latent, + num_text_tokens=self.args.clvp_num_text_tokens, + text_enc_depth=self.args.clvp_text_enc_depth, + text_seq_len=self.args.clvp_text_seq_len, + text_heads=self.args.clvp_text_heads, + num_speech_tokens=self.args.clvp_num_speech_tokens, + speech_enc_depth=self.args.clvp_speech_enc_depth, + speech_heads=self.args.clvp_speech_heads, + speech_seq_len=self.args.clvp_speech_seq_len, + use_xformers=self.args.clvp_use_xformers, + ).cpu() + + self.vocoder = self.args.vocoder.value.constructor().cpu() + + # Random latent generators (RLGs) are loaded lazily. + self.rlg_auto = None + self.rlg_diffusion = None + + if self.args.high_vram: + self.autoregressive = self.autoregressive.to(self.device) + self.diffusion = self.diffusion.to(self.device) + self.clvp = self.clvp.to(self.device) + self.vocoder = self.vocoder.to(self.device) + self.high_vram = self.args.high_vram + + @contextmanager + def temporary_cuda(self, model): + if self.high_vram: + yield model + else: + m = model.to(self.device) + yield m + m = model.cpu() + + def get_conditioning_latents( + self, + voice_samples, + return_mels=False, + latent_averaging_mode=0, + original_tortoise=False, + ): + """ + Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). + These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic + properties. + :param voice_samples: List of arbitrary reference clips, which should be *pairs* of torch tensors containing arbitrary kHz waveform data. + :param latent_averaging_mode: 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + """ + assert latent_averaging_mode in [ + 0, + 1, + 2, + ], "latent_averaging mode has to be one of (0, 1, 2)" + + with torch.no_grad(): + voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples] + + auto_conds = [] + for ls in voice_samples: + auto_conds.append(format_conditioning(ls[0], device=self.device, mel_norm_file=self.mel_norm_path)) + auto_conds = torch.stack(auto_conds, dim=1) + with self.temporary_cuda(self.autoregressive) as ar: + auto_latent = ar.get_conditioning(auto_conds) + + diffusion_conds = [] + + DURS_CONST = self.args.duration_const + for ls in voice_samples: + # The diffuser operates at a sample rate of 24000 (except for the latent inputs) + sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1] + if latent_averaging_mode == 0: + sample = pad_or_truncate(sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + sample.to(self.device), + do_normalization=False, + device=self.device, + ) + diffusion_conds.append(cond_mel) + else: + from math import ceil + + if latent_averaging_mode == 2: + temp_diffusion_conds = [] + for chunk in range(ceil(sample.shape[1] / DURS_CONST)): + current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST] + current_sample = pad_or_truncate(current_sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + current_sample.to(self.device), + 
do_normalization=False, + device=self.device, + ) + if latent_averaging_mode == 1: + diffusion_conds.append(cond_mel) + elif latent_averaging_mode == 2: + temp_diffusion_conds.append(cond_mel) + if latent_averaging_mode == 2: + diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0)) + diffusion_conds = torch.stack(diffusion_conds, dim=1) + + with self.temporary_cuda(self.diffusion) as diffusion: + diffusion_latent = diffusion.get_conditioning(diffusion_conds) + + if return_mels: + return auto_latent, diffusion_latent, auto_conds, diffusion_conds + return auto_latent, diffusion_latent + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. + if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict( + torch.load( + os.path.join(self.models_dir, "rlg_auto.pth"), + map_location=torch.device("cpu"), + ) + ) + self.rlg_diffusion = RandomLatentConverter(2048).eval() + self.rlg_diffusion.load_state_dict( + torch.load( + os.path.join(self.models_dir, "rlg_diffuser.pth"), + map_location=torch.device("cpu"), + ) + ) + with torch.no_grad(): + return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) + + def synthesize(self, text, config, speaker_id="random", voice_dirs=None, **kwargs): + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (TortoiseConfig): Config with inference parameters. + speaker_id (str): One of the available speaker names. If `random`, it generates a random speaker. + voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None. + **kwargs: Inference settings. See `inference()`. + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + + speaker_id = "random" if speaker_id is None else speaker_id + + if voice_dirs is not None: + voice_dirs = [voice_dirs] + voice_samples, conditioning_latents = load_voice(speaker_id, voice_dirs) + + else: + voice_samples, conditioning_latents = load_voice(speaker_id) + + outputs = self.inference_with_config( + text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs + ) + + return_dict = { + "wav": outputs["wav"], + "deterministic_seed": outputs["deterministic_seed"], + "text_inputs": outputs["text"], + "voice_samples": outputs["voice_samples"], + "conditioning_latents": outputs["conditioning_latents"], + } + + return return_dict + + def inference_with_config(self, text, config, **kwargs): + """ + inference with config + #TODO describe in detail + """ + # Use generally found best tuning knobs for generation. + settings = { + "temperature": config.temperature, + "length_penalty": config.length_penalty, + "repetition_penalty": config.repetition_penalty, + "top_p": config.top_p, + "cond_free_k": config.cond_free_k, + "diffusion_temperature": config.diffusion_temperature, + "sampler": config.sampler, + } + # Presets are defined here. 
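+        # Each preset trades quality for speed: `num_autoregressive_samples` sets how many candidates the
+        # autoregressive model generates for CLVP to rank, `diffusion_iterations` sets how many refinement steps
+        # the diffusion decoder runs, and some presets also override `sampler` or `cond_free`.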
+ presets = { + "single_sample": { + "num_autoregressive_samples": 8, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast_old": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 30, + "cond_free": False, + }, + "very_fast": { + "num_autoregressive_samples": 32, + "diffusion_iterations": 30, + "sampler": "dpm++2m", + }, + "fast": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 50, + "sampler": "ddim", + }, + "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80}, + "standard": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 200, + }, + "high_quality": { + "num_autoregressive_samples": 256, + "diffusion_iterations": 400, + }, + } + if "preset" in kwargs: + settings.update(presets[kwargs["preset"]]) + kwargs.pop("preset") + settings.update(kwargs) # allow overriding of preset settings with kwargs + return self.inference(text, **settings) + + def inference( + self, + text, + voice_samples=None, + conditioning_latents=None, + k=1, + verbose=True, + use_deterministic_seed=None, + return_deterministic_state=False, + latent_averaging_mode=0, + # autoregressive generation parameters follow + num_autoregressive_samples=16, + temperature=0.8, + length_penalty=1, + repetition_penalty=2.0, + top_p=0.8, + max_mel_tokens=500, + # diffusion generation parameters follow + diffusion_iterations=100, + cond_free=True, + cond_free_k=2, + diffusion_temperature=1.0, + sampler="ddim", + half=True, + original_tortoise=False, + **hf_generate_kwargs, + ): + """ + This function produces an audio clip of the given text being spoken with the given reference voice. + + Args: + text: (str) Text to be spoken. + voice_samples: (List[Tuple[torch.Tensor]]) List of an arbitrary number of reference clips, which should be tuple-pairs + of torch tensors containing arbitrary kHz waveform data. + conditioning_latents: (Tuple[autoregressive_conditioning_latent, diffusion_conditioning_latent]) A tuple of + (autoregressive_conditioning_latent, diffusion_conditioning_latent), which can be provided in lieu + of voice_samples. This is ignored unless `voice_samples=None`. Conditioning latents can be retrieved + via `get_conditioning_latents()`. + k: (int) The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + latent_averaging_mode: (int) 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + verbose: (bool) Whether or not to print log messages indicating the progress of creating a clip. Default=true. + num_autoregressive_samples: (int) Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + temperature: (float) The softmax temperature of the autoregressive model. + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during decoding. 
Can be used to reduce
+                the incidence of long silences or "uhhhhhhs", etc.
+            top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
+            max_mel_tokens: (int) Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
+            typical_sampling: (bool) Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666
+                I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but could use some tuning.
+            typical_mass: (float) The typical_mass parameter from the typical_sampling algorithm.
+            diffusion_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively
+                refine the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, however.
+            cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for
+                each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output of the two
+                is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
+            cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
+                As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
+            diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
+                are the "mean" prediction of the diffusion network and will sound bland and smeared.
+            hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer.
+                Extra keyword args fed to this function get forwarded directly to that API. Documentation
+                here: https://huggingface.co/docs/transformers/internal/generation_utils
+
+        Returns:
+            Generated audio clip(s) as a torch tensor. Shape (1, S) if k=1, else (k, 1, S), where S is the sample length.
+            Sample rate is 24kHz.
+        """
+        deterministic_seed = deterministic_state(seed=use_deterministic_seed)
+
+        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
+        text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
+        assert (
+            text_tokens.shape[-1] < 400
+        ), "Too much text provided. Break the text up into separate segments and re-try inference."
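+        # Conditioning latents are resolved in order of precedence: explicit reference clips (`voice_samples`),
+        # then pre-computed latents (`conditioning_latents`), and finally random latents drawn from the RLG models.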
+ + if voice_samples is not None: + ( + auto_conditioning, + diffusion_conditioning, + _, + _, + ) = self.get_conditioning_latents( + voice_samples, + return_mels=True, + latent_averaging_mode=latent_averaging_mode, + original_tortoise=original_tortoise, + ) + elif conditioning_latents is not None: + auto_conditioning, diffusion_conditioning = conditioning_latents + else: + ( + auto_conditioning, + diffusion_conditioning, + ) = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + diffusion_conditioning = diffusion_conditioning.to(self.device) + + diffuser = load_discrete_vocoder_diffuser( + desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler + ) + + # in the case of single_sample, + orig_batch_size = self.autoregressive_batch_size + while num_autoregressive_samples % self.autoregressive_batch_size: + self.autoregressive_batch_size //= 2 + with torch.no_grad(): + samples = [] + num_batches = num_autoregressive_samples // self.autoregressive_batch_size + stop_mel_token = self.autoregressive.stop_mel_token + calm_token = ( + 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + ) + self.autoregressive = self.autoregressive.to(self.device) + if verbose: + print("Generating autoregressive samples..") + with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=half + ): + for b in tqdm(range(num_batches), disable=not verbose): + codes = autoregressive.inference_speech( + auto_conditioning, + text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs, + ) + padding_needed = max_mel_tokens - codes.shape[1] + codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) + samples.append(codes) + self.autoregressive_batch_size = orig_batch_size # in the case of single_sample + + clip_results = [] + with self.temporary_cuda(self.clvp) as clvp, torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=half + ): + for batch in tqdm(samples, disable=not verbose): + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) + clvp_res = clvp( + text_tokens.repeat(batch.shape[0], 1), + batch, + return_loss=False, + ) + clip_results.append(clvp_res) + + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=k).indices] + del samples + + # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning + # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these + # results, but will increase memory usage. 
+            with self.temporary_cuda(self.autoregressive) as autoregressive:
+                best_latents = autoregressive(
+                    auto_conditioning.repeat(k, 1),
+                    text_tokens.repeat(k, 1),
+                    torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
+                    best_results,
+                    torch.tensor(
+                        [best_results.shape[-1] * self.autoregressive.mel_length_compression],
+                        device=text_tokens.device,
+                    ),
+                    return_latent=True,
+                    clip_inputs=False,
+                )
+            del auto_conditioning
+
+            if verbose:
+                print("Transforming autoregressive outputs into audio..")
+            wav_candidates = []
+            for b in range(best_results.shape[0]):
+                codes = best_results[b].unsqueeze(0)
+                latents = best_latents[b].unsqueeze(0)
+
+                # Find the first occurrence of the "calm" token and trim the codes to that.
+                ctokens = 0
+                for code in range(codes.shape[-1]):
+                    if codes[0, code] == calm_token:
+                        ctokens += 1
+                    else:
+                        ctokens = 0
+                    if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+                        latents = latents[:, :code]
+                        break
+                with self.temporary_cuda(self.diffusion) as diffusion:
+                    mel = do_spectrogram_diffusion(
+                        diffusion,
+                        diffuser,
+                        latents,
+                        diffusion_conditioning,
+                        temperature=diffusion_temperature,
+                        verbose=verbose,
+                    )
+                with self.temporary_cuda(self.vocoder) as vocoder:
+                    wav = vocoder.inference(mel)
+                wav_candidates.append(wav.cpu())
+
+            def potentially_redact(clip, text):
+                if self.enable_redaction:
+                    return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
+                return clip
+
+            wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
+
+            if len(wav_candidates) > 1:
+                res = wav_candidates
+            else:
+                res = wav_candidates[0]
+
+            return_dict = {
+                "wav": res,
+                "deterministic_seed": None,
+                "text": None,
+                "voice_samples": None,
+                "conditioning_latents": None,
+            }
+            if return_deterministic_state:
+                return_dict = {
+                    "wav": res,
+                    "deterministic_seed": deterministic_seed,
+                    "text": text,
+                    "voice_samples": voice_samples,
+                    "conditioning_latents": conditioning_latents,
+                }
+            return return_dict
+
+    def forward(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    def eval_step(self):
+        raise NotImplementedError("Tortoise Training is not implemented")
+
+    @staticmethod
+    def init_from_config(config: "TortoiseConfig", **kwargs):  # pylint: disable=unused-argument
+        return Tortoise(config)
+
+    def load_checkpoint(
+        self,
+        config,
+        checkpoint_dir,
+        ar_checkpoint_path=None,
+        diff_checkpoint_path=None,
+        clvp_checkpoint_path=None,
+        vocoder_checkpoint_path=None,
+        eval=False,
+        strict=True,
+        **kwargs,
+    ):  # pylint: disable=unused-argument, redefined-builtin
+        """Load the model checkpoints from a directory. This model uses multiple checkpoint files and expects
+        all of them to be under the given `checkpoint_dir` with the right names.
+        If eval is True, set the model to eval mode.
+
+        Args:
+            config (TortoiseConfig): The model config.
+            checkpoint_dir (str): The directory where the checkpoints are stored.
+            ar_checkpoint_path (str, optional): The path to the autoregressive checkpoint. Defaults to None.
+            diff_checkpoint_path (str, optional): The path to the diffusion checkpoint. Defaults to None.
+            clvp_checkpoint_path (str, optional): The path to the CLVP checkpoint. Defaults to None.
+            vocoder_checkpoint_path (str, optional): The path to the vocoder checkpoint. Defaults to None.
+            eval (bool, optional): Whether to set the model to eval mode. Defaults to False.
+            strict (bool, optional): Whether to load the model strictly. Defaults to True.
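+
+        Example:
+            A minimal sketch mirroring the class docstring (the directory path is a placeholder)::
+
+                >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)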
+ """ + if self.models_dir is None: + self.models_dir = checkpoint_dir + ar_path = ar_checkpoint_path or os.path.join(checkpoint_dir, "autoregressive.pth") + diff_path = diff_checkpoint_path or os.path.join(checkpoint_dir, "diffusion_decoder.pth") + clvp_path = clvp_checkpoint_path or os.path.join(checkpoint_dir, "clvp2.pth") + vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") + self.mel_norm_path = os.path.join(checkpoint_dir, "mel_norms.pth") + + if os.path.exists(ar_path): + # remove keys from the checkpoint that are not in the model + checkpoint = torch.load(ar_path, map_location=torch.device("cpu")) + + # strict set False + # due to removed `bias` and `masked_bias` changes in Transformers + self.autoregressive.load_state_dict(checkpoint, strict=False) + + if os.path.exists(diff_path): + self.diffusion.load_state_dict(torch.load(diff_path), strict=strict) + + if os.path.exists(clvp_path): + self.clvp.load_state_dict(torch.load(clvp_path), strict=strict) + + if os.path.exists(vocoder_checkpoint_path): + self.vocoder.load_state_dict( + config.model_args.vocoder.value.optionally_index( + torch.load( + vocoder_checkpoint_path, + map_location=torch.device("cpu"), + ) + ) + ) + + if eval: + self.autoregressive.post_init_gpt2_config(self.args.kv_cache) + self.autoregressive.eval() + self.diffusion.eval() + self.clvp.eval() + self.vocoder.eval() + + def train_step(self): + raise NotImplementedError("Tortoise Training is not implemented") diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py new file mode 100644 index 0000000..d9b1f59 --- /dev/null +++ b/TTS/tts/models/vits.py @@ -0,0 +1,1999 @@ +import math +import os +from dataclasses import dataclass, field, replace +from itertools import chain +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder +from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint +from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment +from TTS.utils.io import load_fsspec +from TTS.utils.samplers import BucketBatchSampler +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from 
TTS.vocoder.utils.generic_utils import plot_results + +############################## +# IO / Feature extraction +############################## + +# pylint: disable=global-statement +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load(file_path) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, 
fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = amp_to_db(spec) + return spec + + +############################# +# CONFIGS +############################# + + +@dataclass +class VitsAudioConfig(Coqpit): + fft_size: int = 1024 + sample_rate: int = 22050 + win_length: int = 1024 + hop_length: int = 256 + num_mels: int = 80 + mel_fmin: int = 0 + mel_fmax: int = None + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create inverse frequency weights for balancing the dataset. + Use `multi_dict` to scale relative weights.""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + # check if all keys are in the multi_dict + for k in multi_dict: + assert k in unique_attr_names, f"{k} not in {unique_attr_names}" + # scale weights + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class VitsDataset(TTSDataset): + def __init__(self, model_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pad_id = self.tokenizer.characters.pad_id + self.model_args = model_args + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + if self.model_args.encoder_sample_rate is not None: + if wav.size(1) % self.model_args.encoder_sample_rate != 0: + wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] + + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "audio_unique_name": item["audio_unique_name"], + } + + @property + def lengths(self): + lens = [] 
+ for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - audio_unique_names: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "tokens": token_padded, + "token_lens": token_lens, + "token_rel_lens": token_rel_lens, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "audio_unique_names": batch["audio_unique_name"], + } + + +############################## +# MODEL DEFINITION +############################## + + +@dataclass +class VitsArgs(Coqpit): + """VITS model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels of the decoder. Defaults to 513. + + spec_segment_size (int): + Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. + + hidden_channels (int): + Number of hidden channels of the model. Defaults to 192. + + hidden_channels_ffn_text_encoder (int): + Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256. + + num_heads_text_encoder (int): + Number of attention heads of the text encoder transformer. Defaults to 2. + + num_layers_text_encoder (int): + Number of transformer layers in the text encoder. Defaults to 6. + + kernel_size_text_encoder (int): + Kernel size of the text encoder transformer FFN layers. Defaults to 3. + + dropout_p_text_encoder (float): + Dropout rate of the text encoder. Defaults to 0.1. + + dropout_p_duration_predictor (float): + Dropout rate of the duration predictor. Defaults to 0.1. + + kernel_size_posterior_encoder (int): + Kernel size of the posterior encoder's WaveNet layers. Defaults to 5. + + dilatation_posterior_encoder (int): + Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1. + + num_layers_posterior_encoder (int): + Number of posterior encoder's WaveNet layers. Defaults to 16. 
+
+        kernel_size_flow (int):
+            Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
+
+        dilation_rate_flow (int):
+            Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
+
+        num_layers_flow (int):
+            Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.
+
+        resblock_type_decoder (str):
+            Type of the residual block in the decoder network. Defaults to "1".
+
+        resblock_kernel_sizes_decoder (List[int]):
+            Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
+
+        resblock_dilation_sizes_decoder (List[List[int]]):
+            Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
+
+        upsample_rates_decoder (List[int]):
+            Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
+            values must be equal to the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
+
+        upsample_initial_channel_decoder (int):
+            Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
+
+        upsample_kernel_sizes_decoder (List[int]):
+            Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
+
+        periods_multi_period_discriminator (List[int]):
+            Period values for the Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`.
+
+        use_sdp (bool):
+            Use Stochastic Duration Predictor. Defaults to True.
+
+        noise_scale (float):
+            Noise scale used for the sample noise tensor in training. Defaults to 1.0.
+
+        inference_noise_scale (float):
+            Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
+
+        length_scale (float):
+            Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.
+
+        noise_scale_dp (float):
+            Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
+
+        inference_noise_scale_dp (float):
+            Noise scale for the Stochastic Duration Predictor in inference. Defaults to 1.0.
+
+        max_inference_len (int):
+            Maximum inference length to limit the memory use. Defaults to None.
+
+        init_discriminator (bool):
+            Initialize the discriminator network if set True. Set False for inference. Defaults to True.
+
+        use_spectral_norm_disriminator (bool):
+            Use spectral normalization over weight norm in the discriminator. Defaults to False.
+
+        use_speaker_embedding (bool):
+            Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
+
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+
+        d_vector_file (List[str]):
+            List of paths to the files including pre-computed speaker embeddings. Defaults to None.
+
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to 0.
+
+        detach_dp_input (bool):
+            Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
+
+        use_language_embedding (bool):
+            Enable/Disable language embedding for multilingual models. Defaults to False.
+
+        embedded_language_dim (int):
+            Number of language embedding channels. Defaults to 4.
+
+        num_languages (int):
+            Number of languages for the language embedding layer. Defaults to 0.
+
+        language_ids_file (str):
+            Path to the language mapping file for the Language Manager. Defaults to None.
+
+        use_speaker_encoder_as_loss (bool):
+            Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
+
+        speaker_encoder_config_path (str):
+            Path to the speaker encoder config file, used for SCL. Defaults to "".
+
+        speaker_encoder_model_path (str):
+            Path to the speaker encoder checkpoint file, used for SCL. Defaults to "".
+
+        condition_dp_on_speaker (bool):
+            Condition the duration predictor on the speaker embedding. Defaults to True.
+
+        freeze_encoder (bool):
+            Freeze the encoder weights during training. Defaults to False.
+
+        freeze_DP (bool):
+            Freeze the duration predictor weights during training. Defaults to False.
+
+        freeze_PE (bool):
+            Freeze the posterior encoder weights during training. Defaults to False.
+
+        freeze_flow_decoder (bool):
+            Freeze the flow decoder weights during training. Defaults to False.
+
+        freeze_waveform_decoder (bool):
+            Freeze the waveform decoder weights during training. Defaults to False.
+
+        encoder_sample_rate (int):
+            If not None, this sample rate will be used for training the Posterior Encoder,
+            flow, text_encoder and duration predictor. The decoder part (vocoder) will be
+            trained with the `config.audio.sample_rate`. Defaults to None.
+
+        interpolate_z (bool):
+            If `encoder_sample_rate` is not None and this parameter is True, nearest interpolation
+            will be used to upsample the latent variable z from the sampling rate `encoder_sample_rate`
+            to the `config.audio.sample_rate`. If it is False you will need to add extra
+            `upsample_rates_decoder` to match the shape. Defaults to True.
+
+    """
+
+    num_chars: int = 100
+    out_channels: int = 513
+    spec_segment_size: int = 32
+    hidden_channels: int = 192
+    hidden_channels_ffn_text_encoder: int = 768
+    num_heads_text_encoder: int = 2
+    num_layers_text_encoder: int = 6
+    kernel_size_text_encoder: int = 3
+    dropout_p_text_encoder: float = 0.1
+    dropout_p_duration_predictor: float = 0.5
+    kernel_size_posterior_encoder: int = 5
+    dilation_rate_posterior_encoder: int = 1
+    num_layers_posterior_encoder: int = 16
+    kernel_size_flow: int = 5
+    dilation_rate_flow: int = 1
+    num_layers_flow: int = 4
+    resblock_type_decoder: str = "1"
+    resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
+    resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
+    upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
+    upsample_initial_channel_decoder: int = 512
+    upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
+    periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
+    use_sdp: bool = True
+    noise_scale: float = 1.0
+    inference_noise_scale: float = 0.667
+    length_scale: float = 1
+    noise_scale_dp: float = 1.0
+    inference_noise_scale_dp: float = 1.0
+    max_inference_len: int = None
+    init_discriminator: bool = True
+    use_spectral_norm_disriminator: bool = False
+    use_speaker_embedding: bool = False
+    num_speakers: int = 0
+    speakers_file: str = None
+    d_vector_file: List[str] = None
+    speaker_embedding_channels: int = 256
+    use_d_vector_file: bool = False
+    d_vector_dim: int = 0
+    detach_dp_input: bool = True
+    use_language_embedding: bool = False
+    embedded_language_dim: int = 4
+    num_languages: int = 0
+
language_ids_file: str = None + use_speaker_encoder_as_loss: bool = False + speaker_encoder_config_path: str = "" + speaker_encoder_model_path: str = "" + condition_dp_on_speaker: bool = True + freeze_encoder: bool = False + freeze_DP: bool = False + freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False + encoder_sample_rate: int = None + interpolate_z: bool = True + reinit_DP: bool = False + reinit_text_encoder: bool = False + + +class Vits(BaseTTS): + """VITS TTS model + + Paper:: + https://arxiv.org/pdf/2106.06103.pdf + + Paper Abstract:: + Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel + sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. + In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than + current two-stage models. Our method adopts variational inference augmented with normalizing flows and + an adversarial training process, which improves the expressive power of generative modeling. We also propose a + stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the + uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the + natural one-to-many relationship in which a text input can be spoken in multiple ways + with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) + on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly + available TTS systems and achieves a MOS comparable to ground truth. + + Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. + + Examples: + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> from TTS.tts.models.vits import Vits + >>> config = VitsConfig() + >>> model = Vits(config) + """ + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager, language_manager) + + self.init_multispeaker(config) + self.init_multilingual(config) + self.init_upsampling() + + self.length_scale = self.args.length_scale + self.noise_scale = self.args.noise_scale + self.inference_noise_scale = self.args.inference_noise_scale + self.inference_noise_scale_dp = self.args.inference_noise_scale_dp + self.noise_scale_dp = self.args.noise_scale_dp + self.max_inference_len = self.args.max_inference_len + self.spec_segment_size = self.args.spec_segment_size + + self.text_encoder = TextEncoder( + self.args.num_chars, + self.args.hidden_channels, + self.args.hidden_channels, + self.args.hidden_channels_ffn_text_encoder, + self.args.num_heads_text_encoder, + self.args.num_layers_text_encoder, + self.args.kernel_size_text_encoder, + self.args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, + ) + + self.posterior_encoder = PosteriorEncoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_posterior_encoder, + dilation_rate=self.args.dilation_rate_posterior_encoder, + num_layers=self.args.num_layers_posterior_encoder, + cond_channels=self.embedded_speaker_dim, + ) + + self.flow = ResidualCouplingBlocks( + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_flow, + 
dilation_rate=self.args.dilation_rate_flow, + num_layers=self.args.num_layers_flow, + cond_channels=self.embedded_speaker_dim, + ) + + if self.args.use_sdp: + self.duration_predictor = StochasticDurationPredictor( + self.args.hidden_channels, + 192, + 3, + self.args.dropout_p_duration_predictor, + 4, + cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, + language_emb_dim=self.embedded_language_dim, + ) + else: + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + 256, + 3, + self.args.dropout_p_duration_predictor, + cond_channels=self.embedded_speaker_dim, + language_emb_dim=self.embedded_language_dim, + ) + + self.waveform_decoder = HifiganGenerator( + self.args.hidden_channels, + 1, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, + inference_padding=0, + cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.args.init_discriminator: + self.disc = VitsDiscriminator( + periods=self.args.periods_multi_period_discriminator, + use_spectral_norm=self.args.use_spectral_norm_disriminator, + ) + + @property + def device(self): + return next(self.parameters()).device + + def init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + # TODO: make this a function + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.encoder is None and ( + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + ) + + self.speaker_manager.encoder.eval() + print(" > External Speaker Encoder Loaded !!") + + if ( + hasattr(self.speaker_manager.encoder, "audio_config") + and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"] + ): + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], + ) + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] 
Speaker embedding layer already initialized before d_vector settings.")
+        self.embedded_speaker_dim = self.args.d_vector_dim
+
+    def init_multilingual(self, config: Coqpit):
+        """Initialize multilingual modules of a model.
+
+        Args:
+            config (Coqpit): Model configuration.
+        """
+        if self.args.language_ids_file is not None:
+            self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
+
+        if self.args.use_language_embedding and self.language_manager:
+            print(" > initialization of language-embedding layers.")
+            self.num_languages = self.language_manager.num_languages
+            self.embedded_language_dim = self.args.embedded_language_dim
+            self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
+            torch.nn.init.xavier_uniform_(self.emb_l.weight)
+        else:
+            self.embedded_language_dim = 0
+
+    def init_upsampling(self):
+        """
+        Initialize upsampling modules of a model.
+        """
+        if self.args.encoder_sample_rate:
+            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
+            self.audio_resampler = torchaudio.transforms.Resample(
+                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
+            )  # pylint: disable=W0201
+
+    def on_epoch_start(self, trainer):  # pylint: disable=W0613
+        """Freeze layers at the beginning of an epoch"""
+        self._freeze_layers()
+        # set the device of speaker encoder
+        if self.args.use_speaker_encoder_as_loss:
+            self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device)
+
+    def on_init_end(self, trainer):  # pylint: disable=W0613
+        """Reinit layers if needed"""
+        if self.args.reinit_DP:
+            before_dict = get_module_weights_sum(self.duration_predictor)
+            # Applies weights_reset recursively to every submodule of the duration predictor
+            self.duration_predictor.apply(fn=weights_reset)
+            after_dict = get_module_weights_sum(self.duration_predictor)
+            for key, value in after_dict.items():
+                if value == before_dict[key]:
+                    raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !")
+            print(" > Duration Predictor was reinit.")
+
+        if self.args.reinit_text_encoder:
+            before_dict = get_module_weights_sum(self.text_encoder)
+            # Applies weights_reset recursively to every submodule of the text encoder
+            self.text_encoder.apply(fn=weights_reset)
+            after_dict = get_module_weights_sum(self.text_encoder)
+            for key, value in after_dict.items():
+                if value == before_dict[key]:
+                    raise RuntimeError(" [!] 
The weights of Text Encoder was not reinit check it !") + print(" > Text Encoder was reinit.") + + def get_aux_input(self, aux_input: Dict): + sid, g, lid, _ = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _freeze_layers(self): + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid, durations = None, None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) + if g.ndim == 2: + g = g.unsqueeze_(0) + + if "language_ids" in aux_input and aux_input["language_ids"] is not None: + lid = aux_input["language_ids"] + if lid.ndim == 0: + lid = lid.unsqueeze_(0) + + if "durations" in aux_input and aux_input["durations"] is not None: + durations = aux_input["durations"] + + return sid, g, lid, durations + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] 
Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + + # duration predictor + attn_durations = attn.sum(3) + if self.args.use_sdp: + loss_duration = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + attn_durations, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = loss_duration / torch.sum(x_mask) + else: + attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + log_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) + outputs["loss_duration"] = loss_duration + return outputs, attn + + def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): + spec_segment_size = self.spec_segment_size + if self.args.encoder_sample_rate: + # recompute the slices and spec_segment_size if needed + slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids + spec_segment_size = spec_segment_size * int(self.interpolate_factor) + # interpolate z if needed + if self.args.interpolate_z: + z = torch.nn.functional.interpolate(z, scale_factor=[self.interpolate_factor], mode="linear").squeeze(0) + # recompute the mask if needed + if y_lengths is not None and y_mask is not None: + y_mask = ( + sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) + ) # [B, 1, T_dec_resampled] + + return z, spec_segment_size, slice_ids, y_mask + + def forward( # pylint: disable=dangerous-default-value + self, + x: torch.tensor, + x_lengths: torch.tensor, + y: torch.tensor, + y_lengths: torch.tensor, + waveform: torch.tensor, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, + ) -> Dict: + """Forward pass of the model. + + Args: + x (torch.tensor): Batch of input character sequence IDs. + x_lengths (torch.tensor): Batch of input character sequence lengths. + y (torch.tensor): Batch of input spectrograms. + y_lengths (torch.tensor): Batch of input spectrogram lengths. + waveform (torch.tensor): Batch of ground truth waveforms per sample. + aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. + Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. + + Returns: + Dict: model outputs keyed by the output name. 
+ + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - y: :math:`[B, C, T_spec]` + - y_lengths: :math:`[B]` + - waveform: :math:`[B, 1, T_wav]` + - d_vectors: :math:`[B, C, 1]` + - speaker_ids: :math:`[B]` + - language_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + - m_q: :math:`[B, C, T_dec]` + - logs_q: :math:`[B, C, T_dec]` + - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` + - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + """ + outputs = {} + sid, g, lid, _ = self._set_cond_input(aux_input) + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + # posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) + + # flow layers + z_p = self.flow(z, y_mask, g=g) + + # duration predictor + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + + # expand prior + m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + + # select a random feature segment for the waveform decoder + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) + + # interpolate z if needed + z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids) + + o = self.waveform_decoder(z_slice, g=g) + + wav_seg = segment( + waveform, + slice_ids * self.config.audio.hop_length, + spec_segment_size * self.config.audio.hop_length, + pad_short=True, + ) + + if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: + # concate generated and GT waveforms + wavs_batch = torch.cat((wav_seg, o), dim=0) + + # resample audio to speaker encoder sample_rate + # pylint: disable=W0105 + if self.audio_transform is not None: + wavs_batch = self.audio_transform(wavs_batch) + + pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True) + + # split generated and GT speaker embeddings + gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + else: + gt_spk_emb, syn_spk_emb = None, None + + outputs.update( + { + "model_outputs": o, + "alignments": attn.squeeze(1), + "m_p": m_p, + "logs_p": logs_p, + "z": z, + "z_p": z_p, + "m_q": m_q, + "logs_q": logs_q, + "waveform_seg": wav_seg, + "gt_spk_emb": gt_spk_emb, + "syn_spk_emb": syn_spk_emb, + "slice_ids": slice_ids, + } + ) + return outputs + + @staticmethod + def _set_x_lengths(x, aux_input): + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, + x, + aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None}, + ): # pylint: disable=dangerous-default-value + """ + Note: + To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. 
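+
+        Example:
+            A minimal sketch (the token IDs below are arbitrary placeholders, not real vocabulary IDs):
+
+            >>> import torch
+            >>> x = torch.randint(0, 10, (1, 20))  # [B, T_seq]
+            >>> outputs = model.inference(x)
+            >>> outputs["model_outputs"].shape  # [B, 1, T_wav]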
+ + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - d_vectors: :math:`[B, C]` + - speaker_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + """ + sid, g, lid, durations = self._set_cond_input(aux_input) + x_lengths = self._set_x_lengths(x, aux_input) + + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + if durations is None: + if self.args.use_sdp: + logw = self.duration_predictor( + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, + ) + else: + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) + w = torch.exp(logw) * x_mask * self.length_scale + else: + assert durations.shape[-1] == x.shape[-1] + w = durations.unsqueeze(0) + + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] + + attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) + + m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + + # upsampling if needed + z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask) + + o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) + + outputs = { + "model_outputs": o, + "alignments": attn.squeeze(1), + "durations": w_ceil, + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference_voice_conversion( + self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None + ): + """Inference for voice conversion + + Args: + reference_wav (Tensor): Reference wavform. Tensor of shape [B, T] + speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B] + d_vector (Tensor): d_vector embedding of target speaker. Tensor of shape `[B, C]` + reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B] + reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. 
Tensor of shape `[B, C]` + """ + # compute spectrograms + y = wav_to_spec( + reference_wav, + self.config.audio.fft_size, + self.config.audio.hop_length, + self.config.audio.win_length, + center=False, + ) + y_lengths = torch.tensor([y.size(-1)]).to(y.device) + speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector + speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector + wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt) + return wav + + def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): + """Forward pass for voice conversion + + TODO: create an end-point for voice conversion + + Args: + y (Tensor): Reference spectrograms. Tensor of shape [B, C, T] + y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B] + speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,] + speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,] + """ + assert self.num_speakers > 0, "num_speakers must be larger than 0." + # speaker embedding + if self.args.use_speaker_embedding and not self.args.use_d_vector_file: + g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1) + g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1) + elif not self.args.use_speaker_embedding and self.args.use_d_vector_file: + g_src = F.normalize(speaker_cond_src).unsqueeze(-1) + g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1) + else: + raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.") + + z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Perform a single training step. Run the model forward pass and compute losses. + + Args: + batch (Dict): Input tensors. + criterion (nn.Module): Loss layer designed for the model. + optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks. + + Returns: + Tuple[Dict, Dict]: Model outputs and computed losses.
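+
+             Example (illustrative sketch of how the two passes are typically driven over one batch;
+             `model` is a trained `Vits` instance and `batch` an already formatted training batch):
+
+                 >>> criteria = model.get_criterion()  # [VitsDiscriminatorLoss, VitsGeneratorLoss]
+                 >>> outputs, disc_losses = model.train_step(batch, criteria, optimizer_idx=0)  # runs forward() and caches the outputs
+                 >>> _, gen_losses = model.train_step(batch, criteria, optimizer_idx=1)  # reuses the cached outputs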
+ """ + + spec_lens = batch["spec_lens"] + + if optimizer_idx == 0: + tokens = batch["tokens"] + token_lenghts = batch["token_lens"] + spec = batch["spec"] + + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + language_ids = batch["language_ids"] + waveform = batch["waveform"] + + # generator pass + outputs = self.forward( + tokens, + token_lenghts, + spec, + spec_lens, + waveform, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + # compute scores and features + scores_disc_fake, _, scores_disc_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_real, + scores_disc_fake, + ) + return outputs, loss_dict + + if optimizer_idx == 1: + mel = batch["mel"] + + # compute melspec segment + with autocast(enabled=False): + if self.args.encoder_sample_rate: + spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) + else: + spec_segment_size = self.spec_segment_size + + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True + ) + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + fmin=self.config.audio.mel_fmin, + fmax=self.config.audio.mel_fmax, + center=False, + ) + + # compute discriminator scores and features + scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_slice_hat=mel_slice.float(), + mel_slice=mel_slice_hat.float(), + z_p=self.model_outputs_cache["z_p"].float(), + logs_q=self.model_outputs_cache["logs_q"].float(), + m_p=self.model_outputs_cache["m_p"].float(), + logs_p=self.model_outputs_cache["logs_p"].float(), + z_len=spec_lens, + scores_disc_fake=scores_disc_fake, + feats_disc_fake=feats_disc_fake, + feats_disc_real=feats_disc_real, + loss_duration=self.model_outputs_cache["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], + ) + + return self.model_outputs_cache, loss_dict + + raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + + def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + figures = plot_results(y_hat, y, ap, name_prefix) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name_prefix}/audio": sample_voice} + + alignments = outputs[1]["alignments"] + align_img = alignments[0].data.cpu().numpy().T + + figures.update( + { + "alignment": plot_alignment(align_img, output_fig=False), + } + ) + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + ap (AudioProcessor): audio processor used at training. + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previoud training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. + """ + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. 
+ + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.name_to_id and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding: + language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + if self.args.encoder_sample_rate: + wav = self.audio_resampler(batch["waveform"]) + else: + wav = batch["waveform"] + + # compute spectrograms + batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) + + if self.args.encoder_sample_rate: + # recompute spec with high sampling rate to the loss + spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) + # remove extra stft frames if needed + if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): + spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + else: + batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] + else: + spec_mel = batch["spec"] + + batch["mel"] = spec_to_mel( + spec=spec_mel, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + ) + + if self.args.encoder_sample_rate: + assert batch["spec"].shape[2] == int( + 
batch["mel"].shape[2] / self.interpolate_factor + ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + else: + assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + + # compute spectrogram frame lengths + batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() + batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() + + if self.args.encoder_sample_rate: + assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0 + else: + assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 + + # zero the padding frames + batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) + batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + print(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + + # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] + + if weights is not None: + w_sampler = WeightedRandomSampler(weights, len(weights)) + batch_sampler = BucketBatchSampler( + w_sampler, + data=data_items, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + sort_key=lambda x: os.path.getsize(x["audio_file"]), + drop_last=True, + ) + else: + batch_sampler = None + # sampler for DDP + if batch_sampler is None: + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + batch_sampler = ( + DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler + ) # TODO: check batch_sampler with multi-gpu + return batch_sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = VitsDataset( + model_args=self.args, + samples=samples, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + if sampler is None: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. 
+ collate_fn=dataset.collate_fn, + drop_last=False, # setting this False might cause issues in AMP training. + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + if num_gpus > 1: + loader = DataLoader( + dataset, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + """Initialize and return the GAN optimizers based on the config parameters. + It returns two optimizers in a list: the first for the discriminator and the second for the generator, matching the `optimizer_idx` values used in `train_step()`. + Returns: + List: optimizers. + """ + # discriminator optimizer + optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + + # generator optimizer (all parameters except the discriminator) + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer0, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def get_criterion(self): + """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in + `train_step()`""" + from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel + VitsDiscriminatorLoss, + VitsGeneratorLoss, + ) + + return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] + + def load_checkpoint( + self, config, checkpoint_path, eval=False, strict=True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + # compat band-aid for the pre-trained models to not use the encoder baked into the model + # TODO: consider baking the speaker encoder into the model and calling it from there, + # as it is probably easier for model distribution.
+ state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} + + if self.args.encoder_sample_rate is not None and eval: + # audio resampler is not used in inference time + self.audio_resampler = None + + # handle fine-tuning from a checkpoint with additional speakers + if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + emb_g = state["model"]["emb_g.weight"] + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) + emb_g = torch.cat([emb_g, new_row], axis=0) + state["model"]["emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) + + if eval: + self.eval() + assert not self.training + + def load_fairseq_checkpoint( + self, config, checkpoint_dir, eval=False, strict=True + ): # pylint: disable=unused-argument, redefined-builtin + """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms + Performs some changes for compatibility. + + Args: + config (Coqpit): 🐸TTS model config. + checkpoint_dir (str): Path to the checkpoint directory. + eval (bool, optional): Set to True for evaluation. Defaults to False. + """ + import json + + from TTS.tts.utils.text.cleaners import basic_cleaners + + self.disc = None + # set paths + config_file = os.path.join(checkpoint_dir, "config.json") + checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth") + vocab_file = os.path.join(checkpoint_dir, "vocab.txt") + # set config params + with open(config_file, "r", encoding="utf-8") as file: + # Load the JSON data as a dictionary + config_org = json.load(file) + self.config.audio.sample_rate = config_org["data"]["sampling_rate"] + # self.config.add_blank = config['add_blank'] + # set tokenizer + vocab = FairseqVocab(vocab_file) + self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels) + self.tokenizer = TTSTokenizer( + use_phonemes=False, + text_cleaner=basic_cleaners, + characters=vocab, + phonemizer=None, + add_blank=config_org["data"]["add_blank"], + use_eos_bos=False, + ) + # load fairseq checkpoint + new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file) + self.load_state_dict(new_chk, strict=strict) + if eval: + self.eval() + assert not self.training + + @staticmethod + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() + + if not config.model_args.encoder_sample_rate: + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + else: + encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate + effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor + assert ( + upsample_rate == effective_hop_length + ), f" [!] 
Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" + + ap = AudioProcessor.init_from_config(config, verbose=verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + language_manager = LanguageManager.init_from_config(config) + + if config.model_args.speaker_encoder_model_path: + speaker_manager.init_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) + return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + + def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True): + """Export model to ONNX format for inference + + Args: + output_path (str): Path to save the exported model. + verbose (bool): Print verbose information. Defaults to True. + """ + + # rollback values + _forward = self.forward + disc = None + if hasattr(self, "disc"): + disc = self.disc + training = self.training + + # set export mode + self.disc = None + self.eval() + + def onnx_inference(text, text_lengths, scales, sid=None, langid=None): + noise_scale = scales[0] + length_scale = scales[1] + noise_scale_dp = scales[2] + self.noise_scale = noise_scale + self.length_scale = length_scale + self.noise_scale_dp = noise_scale_dp + return self.inference( + text, + aux_input={ + "x_lengths": text_lengths, + "d_vectors": None, + "speaker_ids": sid, + "language_ids": langid, + "durations": None, + }, + )["model_outputs"] + + self.forward = onnx_inference + + # set dummy inputs + dummy_input_length = 100 + sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long) + sequence_lengths = torch.LongTensor([sequences.size(1)]) + scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp]) + dummy_input = (sequences, sequence_lengths, scales) + input_names = ["input", "input_lengths", "scales"] + + if self.num_speakers > 0: + speaker_id = torch.LongTensor([0]) + dummy_input += (speaker_id,) + input_names.append("sid") + + if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0: + language_id = torch.LongTensor([0]) + dummy_input += (language_id,) + input_names.append("langid") + + # export to ONNX + torch.onnx.export( + model=self, + args=dummy_input, + opset_version=15, + f=output_path, + verbose=verbose, + input_names=input_names, + output_names=["output"], + dynamic_axes={ + "input": {0: "batch_size", 1: "phonemes"}, + "input_lengths": {0: "batch_size"}, + "output": {0: "batch_size", 1: "time1", 2: "time2"}, + }, + ) + + # rollback + self.forward = _forward + if training: + self.train() + if not disc is None: + self.disc = disc + + def load_onnx(self, model_path: str, cuda=False): + import onnxruntime as ort + + providers = [ + "CPUExecutionProvider" + if cuda is False + else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}) + ] + sess_options = ort.SessionOptions() + self.onnx_sess = ort.InferenceSession( + model_path, + sess_options=sess_options, + providers=providers, + ) + + def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None): + """ONNX inference""" + + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() + + if x_lengths is None: + x_lengths = np.array([x.shape[1]], dtype=np.int64) + + if isinstance(x_lengths, torch.Tensor): + x_lengths = x_lengths.cpu().numpy() + scales = np.array( + [self.inference_noise_scale, self.length_scale, 
self.inference_noise_scale_dp], + dtype=np.float32, + ) + input_params = {"input": x, "input_lengths": x_lengths, "scales": scales} + if not speaker_id is None: + input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() + if not language_id is None: + input_params["langid"] = torch.tensor([language_id]).cpu().numpy() + + audio = self.onnx_sess.run( + ["output"], + input_params, + ) + return audio[0][0] + + +################################## +# VITS CHARACTERS +################################## + + +class VitsCharacters(BaseCharacters): + """Characters class for VITs model for compatibility with pre-trained models""" + + def __init__( + self, + graphemes: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + ipa_characters: str = _phonemes, + ) -> None: + if ipa_characters is not None: + graphemes += ipa_characters + super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) + + def _create_vocab(self): + self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + # pylint: disable=unnecessary-comprehension + self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + + @staticmethod + def init_from_config(config: Coqpit): + if config.characters is not None: + _pad = config.characters["pad"] + _punctuations = config.characters["punctuations"] + _letters = config.characters["characters"] + _letters_ipa = config.characters["phonemes"] + return ( + VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), + config, + ) + characters = VitsCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=None, + bos=None, + blank=self._blank, + is_unique=False, + is_sorted=True, + ) + + +class FairseqVocab(BaseVocabulary): + def __init__(self, vocab: str): + super(FairseqVocab).__init__() + self.vocab = vocab + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab_file): + with open(vocab_file, encoding="utf-8") as f: + self._vocab = [x.replace("\n", "") for x in f.readlines()] + self.blank = self._vocab[0] + self.pad = " " + self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension + self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py new file mode 100644 index 0000000..8e9d6bd --- /dev/null +++ b/TTS/tts/models/xtts.py @@ -0,0 +1,791 @@ +import os +from dataclasses import dataclass + +import librosa +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit + +from TTS.tts.layers.xtts.gpt import GPT +from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder +from TTS.tts.layers.xtts.stream_generator import init_stream_support +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence +from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager +from TTS.tts.models.base_tts import BaseTTS +from TTS.utils.io import load_fsspec + +init_stream_support() + + +def wav_to_mel_cloning( + wav, + mel_norms_file="../experiments/clips_mel_norms.pth", + 
mel_norms=None, + device=torch.device("cpu"), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, +): + """ + Convert waveform to mel-spectrogram with hard-coded parameters for cloning. + + Args: + wav (torch.Tensor): Input waveform tensor. + mel_norms_file (str): Path to mel-spectrogram normalization file. + mel_norms (torch.Tensor): Mel-spectrogram normalization tensor. + device (torch.device): Device to use for computation. + + Returns: + torch.Tensor: Mel-spectrogram tensor. + """ + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + power=power, + normalized=normalized, + sample_rate=sample_rate, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +def load_audio(audiopath, sampling_rate): + # better load setting following: https://github.com/faroit/python_audio_loading_benchmark + + # torchaudio should chose proper backend to load audio depending on platform + audio, lsr = torchaudio.load(audiopath) + + # stereo to mono if needed + if audio.size(0) != 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 10) or not torch.any(audio < 0): + print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + # clip audio invalid values + audio.clip_(-1, 1) + return audio + + +def pad_or_truncate(t, length): + """ + Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it. + + Args: + t (torch.Tensor): The input tensor to be padded or truncated. + length (int): The desired length of the tensor. + + Returns: + torch.Tensor: The padded or truncated tensor. + """ + tp = t[..., :length] + if t.shape[-1] == length: + tp = t + elif t.shape[-1] < length: + tp = F.pad(t, (0, length - t.shape[-1])) + return tp + + +@dataclass +class XttsAudioConfig(Coqpit): + """ + Configuration class for audio-related parameters in the XTTS model. + + Args: + sample_rate (int): The sample rate in which the GPT operates. + output_sample_rate (int): The sample rate of the output audio waveform. + """ + + sample_rate: int = 22050 + output_sample_rate: int = 24000 + + +@dataclass +class XttsArgs(Coqpit): + """A dataclass to represent XTTS model arguments that define the model structure. + + Args: + gpt_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. + decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. 
+ num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + + For GPT model: + gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70. + gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. + gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024. + gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True. + gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False. + """ + + gpt_batch_size: int = 1 + enable_redaction: bool = False + kv_cache: bool = True + gpt_checkpoint: str = None + clvp_checkpoint: str = None + decoder_checkpoint: str = None + num_chars: int = 255 + + # XTTS GPT Encoder params + tokenizer_file: str = "" + gpt_max_audio_tokens: int = 605 + gpt_max_text_tokens: int = 402 + gpt_max_prompt_tokens: int = 70 + gpt_layers: int = 30 + gpt_n_model_channels: int = 1024 + gpt_n_heads: int = 16 + gpt_number_text_tokens: int = None + gpt_start_text_token: int = None + gpt_stop_text_token: int = None + gpt_num_audio_tokens: int = 8194 + gpt_start_audio_token: int = 8192 + gpt_stop_audio_token: int = 8193 + gpt_code_stride_len: int = 1024 + gpt_use_masking_gt_prompt_approach: bool = True + gpt_use_perceiver_resampler: bool = False + + # HifiGAN Decoder params + input_sample_rate: int = 22050 + output_sample_rate: int = 24000 + output_hop_length: int = 256 + decoder_input_dim: int = 1024 + d_vector_dim: int = 512 + cond_d_vector_in_each_upsampling_layer: bool = True + + # constants + duration_const: int = 102400 + + +class Xtts(BaseTTS): + """ⓍTTS model implementation. + + ❗ Currently it only supports inference. 
+ + Examples: + >>> from TTS.tts.configs.xtts_config import XttsConfig + >>> from TTS.tts.models.xtts import Xtts + >>> config = XttsConfig() + >>> model = Xtts.init_from_config(config) + >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + """ + + def __init__(self, config: Coqpit): + super().__init__(config, ap=None, tokenizer=None) + self.mel_stats_path = None + self.config = config + self.gpt_checkpoint = self.args.gpt_checkpoint + self.decoder_checkpoint = self.args.decoder_checkpoint # TODO: check if this is even needed + self.models_dir = config.model_dir + self.gpt_batch_size = self.args.gpt_batch_size + + self.tokenizer = VoiceBpeTokenizer() + self.gpt = None + self.init_models() + self.register_buffer("mel_stats", torch.ones(80)) + + def init_models(self): + """Initialize the models. We do it here since we need to load the tokenizer first.""" + if self.tokenizer.tokenizer is not None: + self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens() + self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]") + self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]") + + if self.args.gpt_number_text_tokens: + self.gpt = GPT( + layers=self.args.gpt_layers, + model_dim=self.args.gpt_n_model_channels, + start_text_token=self.args.gpt_start_text_token, + stop_text_token=self.args.gpt_stop_text_token, + heads=self.args.gpt_n_heads, + max_text_tokens=self.args.gpt_max_text_tokens, + max_mel_tokens=self.args.gpt_max_audio_tokens, + max_prompt_tokens=self.args.gpt_max_prompt_tokens, + number_text_tokens=self.args.gpt_number_text_tokens, + num_audio_tokens=self.args.gpt_num_audio_tokens, + start_audio_token=self.args.gpt_start_audio_token, + stop_audio_token=self.args.gpt_stop_audio_token, + use_perceiver_resampler=self.args.gpt_use_perceiver_resampler, + code_stride_len=self.args.gpt_code_stride_len, + ) + + self.hifigan_decoder = HifiDecoder( + input_sample_rate=self.args.input_sample_rate, + output_sample_rate=self.args.output_sample_rate, + output_hop_length=self.args.output_hop_length, + ar_mel_length_compression=self.args.gpt_code_stride_len, + decoder_input_dim=self.args.decoder_input_dim, + d_vector_dim=self.args.d_vector_dim, + cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, + ) + + @property + def device(self): + return next(self.parameters()).device + + @torch.inference_mode() + def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6): + """Compute the conditioning latents for the GPT model from the given audio. + + Args: + audio (tensor): audio tensor. + sr (int): Sample rate of the audio. + length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30. + chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio + is used without chunking. It must be <= `length`. Defaults to 6.
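+
+             Example (illustrative; `model` is a loaded `Xtts` instance and `audio` a mono 22.05 kHz waveform tensor):
+
+                 >>> cond_latent = model.get_gpt_cond_latents(audio, sr=22050, length=30, chunk_length=6)
+
+             With these defaults and the perceiver resampler enabled, up to five 6-second chunks (22050 * 6 samples each)
+             are embedded and their style embeddings averaged, skipping chunks shorter than ~0.33 s; otherwise a single
+             mel spectrogram over the whole clip is used.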
+ """ + if sr != 22050: + audio = torchaudio.functional.resample(audio, sr, 22050) + if length > 0: + audio = audio[:, : 22050 * length] + if self.args.gpt_use_perceiver_resampler: + style_embs = [] + for i in range(0, audio.shape[1], 22050 * chunk_length): + audio_chunk = audio[:, i : i + 22050 * chunk_length] + + # if the chunk is too short ignore it + if audio_chunk.size(-1) < 22050 * 0.33: + continue + + mel_chunk = wav_to_mel_cloning( + audio_chunk, + mel_norms=self.mel_stats.cpu(), + n_fft=2048, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None) + style_embs.append(style_emb) + + # mean style embedding + cond_latent = torch.stack(style_embs).mean(dim=0) + else: + mel = wav_to_mel_cloning( + audio, + mel_norms=self.mel_stats.cpu(), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + cond_latent = self.gpt.get_style_emb(mel.to(self.device)) + return cond_latent.transpose(1, 2) + + @torch.inference_mode() + def get_speaker_embedding(self, audio, sr): + audio_16k = torchaudio.functional.resample(audio, sr, 16000) + return ( + self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True) + .unsqueeze(-1) + .to(self.device) + ) + + @torch.inference_mode() + def get_conditioning_latents( + self, + audio_path, + max_ref_length=30, + gpt_cond_len=6, + gpt_cond_chunk_len=6, + librosa_trim_db=None, + sound_norm_refs=False, + load_sr=22050, + ): + """Get the conditioning latents for the GPT model from the given audio. + + Args: + audio_path (str or List[str]): Path to reference audio file(s). + max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30. + gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6. + gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_conf_len. Defaults to 6. + librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None. + sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False. + load_sr (int, optional): Sample rate to load the audio. Defaults to 24000. 
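+
+             Example (illustrative; the file path is a placeholder):
+
+                 >>> gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
+                 ...     audio_path=["reference.wav"], gpt_cond_len=30, gpt_cond_chunk_len=6
+                 ... )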
+ """ + # deal with multiples references + if not isinstance(audio_path, list): + audio_paths = [audio_path] + else: + audio_paths = audio_path + + speaker_embeddings = [] + audios = [] + speaker_embedding = None + for file_path in audio_paths: + audio = load_audio(file_path, load_sr) + audio = audio[:, : load_sr * max_ref_length].to(self.device) + if sound_norm_refs: + audio = (audio / torch.abs(audio).max()) * 0.75 + if librosa_trim_db is not None: + audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] + + # compute latents for the decoder + speaker_embedding = self.get_speaker_embedding(audio, load_sr) + speaker_embeddings.append(speaker_embedding) + + audios.append(audio) + + # merge all the audios and compute the latents for the gpt + full_audio = torch.cat(audios, dim=-1) + gpt_cond_latents = self.get_gpt_cond_latents( + full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len + ) # [1, 1024, T] + + if speaker_embeddings: + speaker_embedding = torch.stack(speaker_embeddings) + speaker_embedding = speaker_embedding.mean(dim=0) + + return gpt_cond_latents, speaker_embedding + + def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs): + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (XttsConfig): Config with inference parameters. + speaker_wav (list): List of paths to the speaker audio files to be used for cloning. + language (str): Language ID of the speaker. + **kwargs: Inference settings. See `inference()`. + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + assert ( + "zh-cn" if language == "zh" else language in self.config.languages + ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" + # Use generally found best tuning knobs for generation. + settings = { + "temperature": config.temperature, + "length_penalty": config.length_penalty, + "repetition_penalty": config.repetition_penalty, + "top_k": config.top_k, + "top_p": config.top_p, + } + settings.update(kwargs) # allow overriding of preset settings with kwargs + if speaker_id is not None: + gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values() + return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings) + settings.update({ + "gpt_cond_len": config.gpt_cond_len, + "gpt_cond_chunk_len": config.gpt_cond_chunk_len, + "max_ref_len": config.max_ref_len, + "sound_norm_refs": config.sound_norm_refs, + }) + return self.full_inference(text, speaker_wav, language, **settings) + + @torch.inference_mode() + def full_inference( + self, + text, + ref_audio_path, + language, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + # Cloning + gpt_cond_len=30, + gpt_cond_chunk_len=6, + max_ref_len=10, + sound_norm_refs=False, + **hf_generate_kwargs, + ): + """ + This function produces an audio clip of the given text being spoken with the given reference voice. + + Args: + text: (str) Text to be spoken. + + ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3 + seconds long. + + language: (str) Language of the voice to be generated. 
+ + temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65. + + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the + model to produce more terse outputs. Defaults to 1.0. + + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during + decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0. + + top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 50. + + top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 0.8. + + gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used + else the first `gpt_cond_len` secs is used. Defaults to 30 seconds. + + gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds. + + hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive + transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + + Returns: + Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents( + audio_path=ref_audio_path, + gpt_cond_len=gpt_cond_len, + gpt_cond_chunk_len=gpt_cond_chunk_len, + max_ref_length=max_ref_len, + sound_norm_refs=sound_norm_refs, + ) + + return self.inference( + text, + language, + gpt_cond_latent, + speaker_embedding, + temperature=temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + **hf_generate_kwargs, + ) + + @torch.inference_mode() + def inference( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + num_beams=1, + speed=1.0, + enable_text_splitting=False, + **hf_generate_kwargs, + ): + language = language.split("-")[0] # remove the country code + length_scale = 1.0 / max(speed, 0.05) + gpt_cond_latent = gpt_cond_latent.to(self.device) + speaker_embedding = speaker_embedding.to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + wavs = [] + gpt_latents_list = [] + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." 
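+
+ # The steps below synthesize one sentence in three stages:
+ #   1. `gpt.generate` autoregressively samples discrete audio codes conditioned on the text tokens and `gpt_cond_latent`.
+ #   2. A second GPT forward pass with `return_latent=True` maps those codes to continuous latents; with the default
+ #      `gpt_code_stride_len` of 1024 at the 22.05 kHz GPT rate, each code spans 1024 samples (~46 ms), which is how
+ #      `expected_output_len` is derived from the number of generated codes.
+ #   3. The latents are optionally time-stretched by `length_scale` (speed control) and vocoded by the HiFi-GAN decoder
+ #      conditioned on `speaker_embedding`.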
+ + with torch.no_grad(): + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) + + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) + + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) + + gpt_latents_list.append(gpt_latents.cpu()) + wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) + + return { + "wav": torch.cat(wavs, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(), + "speaker_embedding": speaker_embedding, + } + + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + # cross fade the overlap section + if overlap_len > len(wav_chunk): + # wav_chunk is smaller than overlap_len, pass on last wav_gen + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :] + else: + # not expecting will hit here as problem happens on last chunk + wav_chunk = wav_gen[-overlap_len:] + return wav_chunk, wav_gen, None + else: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + @torch.inference_mode() + def inference_stream( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # Streaming + stream_chunk_size=20, + overlap_wav_len=1024, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + speed=1.0, + enable_text_splitting=False, + **hf_generate_kwargs, + ): + language = language.split("-")[0] # remove the country code + length_scale = 1.0 / max(speed, 0.05) + gpt_cond_latent = gpt_cond_latent.to(self.device) + speaker_embedding = speaker_embedding.to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." 
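+
+ # Streaming: the GPT runs as a generator that yields one audio code and its hidden latent per step. Every
+ # `stream_chunk_size` tokens (and once generation ends), the latents collected so far are vocoded with the HiFi-GAN
+ # decoder and `handle_chunks` slices out the newly generated portion, cross-fading it with the previous chunk over
+ # `overlap_wav_len` samples before yielding it to the caller.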
+ + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk + + def forward(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) + + def eval_step(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) + + @staticmethod + def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument + return Xtts(config) + + def eval(self): # pylint: disable=redefined-builtin + """Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode.""" + self.gpt.init_gpt_for_inference() + super().eval() + + def get_compatible_checkpoint_state_dict(self, model_path): + checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] + # remove xtts gpt trainer extra keys + ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"] + for key in list(checkpoint.keys()): + # check if it is from the coqui Trainer if so convert it + if key.startswith("xtts."): + new_key = key.replace("xtts.", "") + checkpoint[new_key] = checkpoint[key] + del checkpoint[key] + key = new_key + + # remove unused keys + if key.split(".")[0] in ignore_keys: + del checkpoint[key] + + return checkpoint + + def load_checkpoint( + self, + config, + checkpoint_dir=None, + checkpoint_path=None, + vocab_path=None, + eval=True, + strict=True, + use_deepspeed=False, + speaker_file_path=None, + ): + """ + Loads a checkpoint from disk and initializes the model's state and tokenizer. + + Args: + config (dict): The configuration dictionary for the model. + checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. + checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. + vocab_path (str, optional): The path to the vocabulary file. Defaults to None. + eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. + strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. 
+ + Returns: + None + """ + + model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") + vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") + + if speaker_file_path is None and checkpoint_dir is not None: + speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth") + + self.language_manager = LanguageManager(config) + self.speaker_manager = None + if speaker_file_path is not None and os.path.exists(speaker_file_path): + self.speaker_manager = SpeakerManager(speaker_file_path) + + if os.path.exists(vocab_path): + self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path) + + self.init_models() + + checkpoint = self.get_compatible_checkpoint_state_dict(model_path) + + # deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 does not + try: + self.load_state_dict(checkpoint, strict=strict) + except Exception: # retry after preparing the GPT for inference (v1-style checkpoints) + if eval: + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + self.load_state_dict(checkpoint, strict=strict) + + if eval: + self.hifigan_decoder.eval() + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) + self.gpt.eval() + + def train_step(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + )
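Taken together, the loading and inference paths above compose into a short end-to-end usage sketch. The snippet below is illustrative only: the checkpoint directory and reference clip are placeholders, and it assumes a checkpoint laid out the way `load_checkpoint` expects (model.pth, vocab.json and, optionally, speakers_xtts.pth).

import torch

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="path/to/xtts_checkpoint_dir/", eval=True)
if torch.cuda.is_available():
    model.cuda()

# build the voice prompt from one or more reference clips
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

# synthesize; the returned dict holds the waveform under "wav" as a 24 kHz numpy array
out = model.inference(
    "This is a test sentence for XTTS.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    temperature=0.75,
)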