Add files via upload

This commit is contained in:
Sam Khoze
2024-06-18 13:21:08 -07:00
committed by GitHub
parent 6af40bc6cf
commit dc8b8bca5a
97 changed files with 10910 additions and 0 deletions
+39
View File
@@ -0,0 +1,39 @@
# Mozilla TTS Vocoders (Experimental)
Here there are vocoder model implementations which can be combined with the other TTS models.
Currently, following models are implemented:
- Melgan
- MultiBand-Melgan
- ParallelWaveGAN
- GAN-TTS (Discriminator Only)
It is also very easy to adapt different vocoder models as we provide a flexible and modular (but not too modular) framework.
## Training a model
You can see here an example (Soon)[Colab Notebook]() training MelGAN with LJSpeech dataset.
In order to train a new model, you need to gather all wav files into a folder and give this folder to `data_path` in '''config.json'''
You need to define other relevant parameters in your ```config.json``` and then start traning with the following command.
```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --config_path path/to/config.json```
Example config files can be found under `tts/vocoder/configs/` folder.
You can continue a previous training run by the following command.
```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --continue_path path/to/your/model/folder```
You can fine-tune a pre-trained model by the following command.
```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --restore_path path/to/your/model.pth```
Restoring a model starts a new training in a different folder. It only restores model weights with the given checkpoint file. However, continuing a training starts from the same directory where the previous training run left off.
You can also follow your training runs on Tensorboard as you do with our TTS models.
## Acknowledgement
Thanks to @kan-bayashi for his [repository](https://github.com/kan-bayashi/ParallelWaveGAN) being the start point of our work.
View File
Binary file not shown.
+17
View File
@@ -0,0 +1,17 @@
import importlib
import os
from inspect import isclass
# import all files under configs/
configs_dir = os.path.dirname(__file__)
for file in os.listdir(configs_dir):
path = os.path.join(configs_dir, file)
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
config_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("TTS.vocoder.configs." + config_name)
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if isclass(attribute):
# Add the class to this package's variables
globals()[attribute_name] = attribute
@@ -0,0 +1,106 @@
from dataclasses import dataclass, field
from .shared_configs import BaseGANVocoderConfig
@dataclass
class FullbandMelganConfig(BaseGANVocoderConfig):
"""Defines parameters for FullBand MelGAN vocoder.
Example:
>>> from TTS.vocoder.configs import FullbandMelganConfig
>>> config = FullbandMelganConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fullband_melgan`.
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
'melgan_multiscale_discriminator`.
discriminator_model_params (dict): The discriminator model parameters. Defaults to
'{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}`
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `melgan_generator`.
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 16.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
pad_short (int):
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
use_noise_augment (bool):
enable / disable random noise added to the input waveform. The noise is added after computing the
features. Defaults to True.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
use_stft_loss (bool):
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
use_subband_stft (bool):
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
use_mse_gan_loss (bool):
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
use_hinge_gan_loss (bool):
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
Defaults to False.
use_feat_match_loss (bool):
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
use_l1_spec_loss (bool):
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
stft_loss_params (dict): STFT loss parameters. Default to
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
model loss. Defaults to 0.5.
subband_stft_loss_weight (float):
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
mse_G_loss_weight (float):
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
hinge_G_loss_weight (float):
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
feat_match_loss_weight (float):
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
l1_spec_loss_weight (float):
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
"""
model: str = "fullband_melgan"
# Model specific params
discriminator_model: str = "melgan_multiscale_discriminator"
discriminator_model_params: dict = field(
default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]}
)
generator_model: str = "melgan_generator"
generator_model_params: dict = field(
default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 4}
)
# Training - overrides
batch_size: int = 16
seq_len: int = 8192
pad_short: int = 2000
use_noise_augment: bool = True
use_cache: bool = True
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240],
}
)
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 0.0
+136
View File
@@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
@dataclass
class HifiganConfig(BaseGANVocoderConfig):
"""Defines parameters for FullBand MelGAN vocoder.
Example:
>>> from TTS.vocoder.configs import HifiganConfig
>>> config = HifiganConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `hifigan`.
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
'hifigan_discriminator`.
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `hifigan_generator`.
generator_model_params (dict): Parameters of the generator model. Defaults to
`
{
"upsample_factors": [8, 8, 2, 2],
"upsample_kernel_sizes": [16, 16, 4, 4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"resblock_type": "1",
}
`
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 16.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
pad_short (int):
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
use_noise_augment (bool):
enable / disable random noise added to the input waveform. The noise is added after computing the
features. Defaults to True.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
use_stft_loss (bool):
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
use_subband_stft (bool):
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
use_mse_gan_loss (bool):
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
use_hinge_gan_loss (bool):
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
Defaults to False.
use_feat_match_loss (bool):
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
use_l1_spec_loss (bool):
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
stft_loss_params (dict):
STFT loss parameters. Default to
`{
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240]
}`
l1_spec_loss_params (dict):
L1 spectrogram loss parameters. Default to
`{
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None,
}`
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
model loss. Defaults to 0.5.
subband_stft_loss_weight (float):
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
mse_G_loss_weight (float):
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
hinge_G_loss_weight (float):
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
feat_match_loss_weight (float):
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
l1_spec_loss_weight (float):
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
"""
model: str = "hifigan"
# model specific params
discriminator_model: str = "hifigan_discriminator"
generator_model: str = "hifigan_generator"
generator_model_params: dict = field(
default_factory=lambda: {
"upsample_factors": [8, 8, 2, 2],
"upsample_kernel_sizes": [16, 16, 4, 4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"resblock_type": "1",
}
)
# LOSS PARAMETERS - overrides
use_stft_loss: bool = False
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = True
# loss weights - overrides
stft_loss_weight: float = 0
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 1
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 45
l1_spec_loss_params: dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None,
}
)
# optimizer parameters
lr: float = 1e-4
wd: float = 1e-6
+106
View File
@@ -0,0 +1,106 @@
from dataclasses import dataclass, field
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
@dataclass
class MelganConfig(BaseGANVocoderConfig):
"""Defines parameters for MelGAN vocoder.
Example:
>>> from TTS.vocoder.configs import MelganConfig
>>> config = MelganConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `melgan`.
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
'melgan_multiscale_discriminator`.
discriminator_model_params (dict): The discriminator model parameters. Defaults to
'{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}`
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `melgan_generator`.
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 16.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
pad_short (int):
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
use_noise_augment (bool):
enable / disable random noise added to the input waveform. The noise is added after computing the
features. Defaults to True.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
use_stft_loss (bool):
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
use_subband_stft (bool):
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
use_mse_gan_loss (bool):
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
use_hinge_gan_loss (bool):
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
Defaults to False.
use_feat_match_loss (bool):
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
use_l1_spec_loss (bool):
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
stft_loss_params (dict): STFT loss parameters. Default to
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
model loss. Defaults to 0.5.
subband_stft_loss_weight (float):
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
mse_G_loss_weight (float):
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
hinge_G_loss_weight (float):
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
feat_match_loss_weight (float):
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
l1_spec_loss_weight (float):
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
"""
model: str = "melgan"
# Model specific params
discriminator_model: str = "melgan_multiscale_discriminator"
discriminator_model_params: dict = field(
default_factory=lambda: {"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}
)
generator_model: str = "melgan_generator"
generator_model_params: dict = field(
default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 3}
)
# Training - overrides
batch_size: int = 16
seq_len: int = 8192
pad_short: int = 2000
use_noise_augment: bool = True
use_cache: bool = True
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240],
}
)
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 0
@@ -0,0 +1,144 @@
from dataclasses import dataclass, field
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
@dataclass
class MultibandMelganConfig(BaseGANVocoderConfig):
"""Defines parameters for MultiBandMelGAN vocoder.
Example:
>>> from TTS.vocoder.configs import MultibandMelganConfig
>>> config = MultibandMelganConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `multiband_melgan`.
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
'melgan_multiscale_discriminator`.
discriminator_model_params (dict): The discriminator model parameters. Defaults to
'{
"base_channels": 16,
"max_channels": 512,
"downsample_factors": [4, 4, 4]
}`
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `melgan_generator`.
generator_model_param (dict):
The generator model parameters. Defaults to `{"upsample_factors": [8, 4, 2], "num_res_blocks": 4}`.
use_pqmf (bool):
enable / disable PQMF modulation for multi-band training. Defaults to True.
lr_gen (float):
Initial learning rate for the generator model. Defaults to 0.0001.
lr_disc (float):
Initial learning rate for the discriminator model. Defaults to 0.0001.
optimizer (torch.optim.Optimizer):
Optimizer used for the training. Defaults to `AdamW`.
optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
lr_scheduler_gen (torch.optim.Scheduler):
Learning rate scheduler for the generator. Defaults to `MultiStepLR`.
lr_scheduler_gen_params (dict):
Parameters for the generator learning rate scheduler. Defaults to
`{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`.
lr_scheduler_disc (torch.optim.Scheduler):
Learning rate scheduler for the discriminator. Defaults to `MultiStepLR`.
lr_scheduler_dict_params (dict):
Parameters for the discriminator learning rate scheduler. Defaults to
`{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`.
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 16.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
pad_short (int):
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
use_noise_augment (bool):
enable / disable random noise added to the input waveform. The noise is added after computing the
features. Defaults to True.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
steps_to_start_discriminator (int):
Number of steps required to start training the discriminator. Defaults to 0.
use_stft_loss (bool):`
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
use_subband_stft (bool):
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
use_mse_gan_loss (bool):
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
use_hinge_gan_loss (bool):
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
Defaults to False.
use_feat_match_loss (bool):
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
use_l1_spec_loss (bool):
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
stft_loss_params (dict): STFT loss parameters. Default to
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
model loss. Defaults to 0.5.
subband_stft_loss_weight (float):
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
mse_G_loss_weight (float):
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
hinge_G_loss_weight (float):
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
feat_match_loss_weight (float):
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
l1_spec_loss_weight (float):
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
"""
model: str = "multiband_melgan"
# Model specific params
discriminator_model: str = "melgan_multiscale_discriminator"
discriminator_model_params: dict = field(
default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]}
)
generator_model: str = "multiband_melgan_generator"
generator_model_params: dict = field(default_factory=lambda: {"upsample_factors": [8, 4, 2], "num_res_blocks": 4})
use_pqmf: bool = True
# optimizer - overrides
lr_gen: float = 0.0001 # Initial learning rate.
lr_disc: float = 0.0001 # Initial learning rate.
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
lr_scheduler_gen: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_gen_params: dict = field(
default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
)
lr_scheduler_disc: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(
default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
)
# Training - overrides
batch_size: int = 64
seq_len: int = 16384
pad_short: int = 2000
use_noise_augment: bool = False
use_cache: bool = True
steps_to_start_discriminator: bool = 200000
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = True
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
subband_stft_loss_params: dict = field(
default_factory=lambda: {"n_ffts": [384, 683, 171], "hop_lengths": [30, 60, 10], "win_lengths": [150, 300, 60]}
)
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 108
l1_spec_loss_weight: float = 0
@@ -0,0 +1,134 @@
from dataclasses import dataclass, field
from .shared_configs import BaseGANVocoderConfig
@dataclass
class ParallelWaveganConfig(BaseGANVocoderConfig):
"""Defines parameters for ParallelWavegan vocoder.
Args:
model (str):
Model name used for selecting the right configuration at initialization. Defaults to `gan`.
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
'parallel_wavegan_discriminator`.
discriminator_model_params (dict): The discriminator model kwargs. Defaults to
'{"num_layers": 10}`
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `parallel_wavegan_generator`.
generator_model_param (dict):
The generator model kwargs. Defaults to `{"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}`.
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 16.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
pad_short (int):
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
use_noise_augment (bool):
enable / disable random noise added to the input waveform. The noise is added after computing the
features. Defaults to True.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
steps_to_start_discriminator (int):
Number of steps required to start training the discriminator. Defaults to 0.
use_stft_loss (bool):`
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
use_subband_stft (bool):
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
use_mse_gan_loss (bool):
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
use_hinge_gan_loss (bool):
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
Defaults to False.
use_feat_match_loss (bool):
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
use_l1_spec_loss (bool):
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
stft_loss_params (dict): STFT loss parameters. Default to
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
model loss. Defaults to 0.5.
subband_stft_loss_weight (float):
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
mse_G_loss_weight (float):
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
hinge_G_loss_weight (float):
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
feat_match_loss_weight (float):
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 0.
l1_spec_loss_weight (float):
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
lr_gen (float):
Generator model initial learning rate. Defaults to 0.0002.
lr_disc (float):
Discriminator model initial learning rate. Defaults to 0.0002.
optimizer (torch.optim.Optimizer):
Optimizer used for the training. Defaults to `AdamW`.
optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
lr_scheduler_gen (torch.optim.Scheduler):
Learning rate scheduler for the generator. Defaults to `ExponentialLR`.
lr_scheduler_gen_params (dict):
Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`.
lr_scheduler_disc (torch.optim.Scheduler):
Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`.
lr_scheduler_dict_params (dict):
Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`.
"""
model: str = "parallel_wavegan"
# Model specific params
discriminator_model: str = "parallel_wavegan_discriminator"
discriminator_model_params: dict = field(default_factory=lambda: {"num_layers": 10})
generator_model: str = "parallel_wavegan_generator"
generator_model_params: dict = field(
default_factory=lambda: {"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}
)
# Training - overrides
batch_size: int = 6
seq_len: int = 25600
pad_short: int = 2000
use_noise_augment: bool = False
use_cache: bool = True
steps_to_start_discriminator: int = 200000
target_loss: str = "loss_1"
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = False
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240],
}
)
# loss weights - overrides
stft_loss_weight: float = 0.5
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 2.5
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 0
l1_spec_loss_weight: float = 0
# optimizer overrides
lr_gen: float = 0.0002 # Initial learning rate.
lr_disc: float = 0.0002 # Initial learning rate.
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(
default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}
)
scheduler_after_epoch: bool = False
+182
View File
@@ -0,0 +1,182 @@
from dataclasses import dataclass, field
from TTS.config import BaseAudioConfig, BaseTrainingConfig
@dataclass
class BaseVocoderConfig(BaseTrainingConfig):
"""Shared parameters among all the vocoder models.
Args:
audio (BaseAudioConfig):
Audio processor config instance. Defaultsto `BaseAudioConfig()`.
use_noise_augment (bool):
Augment the input audio with random noise. Defaults to False/
eval_split_size (int):
Number of instances used for evaluation. Defaults to 10.
data_path (str):
Root path of the training data. All the audio files found recursively from this root path are used for
training. Defaults to `""`.
feature_path (str):
Root path to the precomputed feature files. Defaults to None.
seq_len (int):
Length of the waveform segments used for training. Defaults to 1000.
pad_short (int):
Extra padding for the waveforms shorter than `seq_len`. Defaults to 0.
conv_path (int):
Extra padding for the feature frames against convolution of the edge frames. Defaults to MISSING.
Defaults to 0.
use_cache (bool):
enable / disable in memory caching of the computed features. If the RAM is not enough, if may cause OOM.
Defaults to False.
epochs (int):
Number of training epochs to. Defaults to 10000.
wd (float):
Weight decay.
optimizer (torch.optim.Optimizer):
Optimizer used for the training. Defaults to `AdamW`.
optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
# dataloading
use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms.
eval_split_size: int = 10 # number of samples used for evaluation.
# dataset
data_path: str = "" # root data path. It finds all wav files recursively from there.
feature_path: str = None # if you use precomputed features
seq_len: int = 1000 # signal length used in training.
pad_short: int = 0 # additional padding for short wavs
conv_pad: int = 0 # additional padding against convolutions applied to spectrograms
use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM.
# OPTIMIZER
epochs: int = 10000 # total number of epochs to train.
wd: float = 0.0 # Weight decay weight.
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
@dataclass
class BaseGANVocoderConfig(BaseVocoderConfig):
"""Base config class used among all the GAN based vocoders.
Args:
use_stft_loss (bool):
enable / disable the use of STFT loss. Defaults to True.
use_subband_stft_loss (bool):
enable / disable the use of Subband STFT loss. Defaults to True.
use_mse_gan_loss (bool):
enable / disable the use of Mean Squared Error based GAN loss. Defaults to True.
use_hinge_gan_loss (bool):
enable / disable the use of Hinge GAN loss. Defaults to True.
use_feat_match_loss (bool):
enable / disable feature matching loss. Defaults to True.
use_l1_spec_loss (bool):
enable / disable L1 spectrogram loss. Defaults to True.
stft_loss_weight (float):
Loss weight that multiplies the computed loss value. Defaults to 0.
subband_stft_loss_weight (float):
Loss weight that multiplies the computed loss value. Defaults to 0.
mse_G_loss_weight (float):
Loss weight that multiplies the computed loss value. Defaults to 1.
hinge_G_loss_weight (float):
Loss weight that multiplies the computed loss value. Defaults to 0.
feat_match_loss_weight (float):
Loss weight that multiplies the computed loss value. Defaults to 100.
l1_spec_loss_weight (float):
Loss weight that multiplies the computed loss value. Defaults to 45.
stft_loss_params (dict):
Parameters for the STFT loss. Defaults to `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`.
l1_spec_loss_params (dict):
Parameters for the L1 spectrogram loss. Defaults to
`{
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None,
}`
target_loss (str):
Target loss name that defines the quality of the model. Defaults to `G_avg_loss`.
grad_clip (list):
A list of gradient clipping theresholds for each optimizer. Any value less than 0 disables clipping.
Defaults to [5, 5].
lr_gen (float):
Generator model initial learning rate. Defaults to 0.0002.
lr_disc (float):
Discriminator model initial learning rate. Defaults to 0.0002.
lr_scheduler_gen (torch.optim.Scheduler):
Learning rate scheduler for the generator. Defaults to `ExponentialLR`.
lr_scheduler_gen_params (dict):
Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
lr_scheduler_disc (torch.optim.Scheduler):
Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`.
lr_scheduler_disc_params (dict):
Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
scheduler_after_epoch (bool):
Whether to update the learning rate schedulers after each epoch. Defaults to True.
use_pqmf (bool):
enable / disable PQMF for subband approximation at training. Defaults to False.
steps_to_start_discriminator (int):
Number of steps required to start training the discriminator. Defaults to 0.
diff_samples_for_G_and_D (bool):
enable / disable use of different training samples for the generator and the discriminator iterations.
Enabling it results in slower iterations but faster convergance in some cases. Defaults to False.
"""
model: str = "gan"
# LOSS PARAMETERS
use_stft_loss: bool = True
use_subband_stft_loss: bool = True
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = True
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
use_l1_spec_loss: bool = True
# loss weights
stft_loss_weight: float = 0
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 1
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 100
l1_spec_loss_weight: float = 45
stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240],
}
)
l1_spec_loss_params: dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None,
}
)
target_loss: str = "loss_0" # loss value to pick the best model to save after each epoch
# optimizer
grad_clip: float = field(default_factory=lambda: [5, 5])
lr_gen: float = 0.0002 # Initial learning rate.
lr_disc: float = 0.0002 # Initial learning rate.
lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
scheduler_after_epoch: bool = True
use_pqmf: bool = False # enable/disable using pqmf for multi-band training. (Multi-band MelGAN)
steps_to_start_discriminator = 0 # start training the discriminator after this number of steps.
diff_samples_for_G_and_D: bool = False # use different samples for G and D training steps.
+161
View File
@@ -0,0 +1,161 @@
from dataclasses import dataclass, field
from typing import Dict
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
@dataclass
class UnivnetConfig(BaseGANVocoderConfig):
"""Defines parameters for UnivNet vocoder.
Example:
>>> from TTS.vocoder.configs import UnivNetConfig
>>> config = UnivNetConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `UnivNet`.
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
'UnivNet_discriminator`.
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `UnivNet_generator`.
generator_model_params (dict): Parameters of the generator model. Defaults to
`
{
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None,
}
`
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 32.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
pad_short (int):
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
use_noise_augment (bool):
enable / disable random noise added to the input waveform. The noise is added after computing the
features. Defaults to True.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
use_stft_loss (bool):
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
use_subband_stft (bool):
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
use_mse_gan_loss (bool):
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
use_hinge_gan_loss (bool):
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
Defaults to False.
use_feat_match_loss (bool):
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
use_l1_spec_loss (bool):
enable / disable using L1 spectrogram loss originally used by univnet model. Defaults to False.
stft_loss_params (dict):
STFT loss parameters. Default to
`{
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240]
}`
l1_spec_loss_params (dict):
L1 spectrogram loss parameters. Default to
`{
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None,
}`
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
model loss. Defaults to 0.5.
subband_stft_loss_weight (float):
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
mse_G_loss_weight (float):
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
hinge_G_loss_weight (float):
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
feat_match_loss_weight (float):
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
l1_spec_loss_weight (float):
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
"""
model: str = "univnet"
batch_size: int = 32
# model specific params
discriminator_model: str = "univnet_discriminator"
generator_model: str = "univnet_generator"
generator_model_params: Dict = field(
default_factory=lambda: {
"in_channels": 64,
"out_channels": 1,
"hidden_channels": 32,
"cond_channels": 80,
"upsample_factors": [8, 8, 4],
"lvc_layers_each_block": 4,
"lvc_kernel_size": 3,
"kpnet_hidden_channels": 64,
"kpnet_conv_size": 3,
"dropout": 0.0,
}
)
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True
use_subband_stft_loss: bool = False
use_mse_gan_loss: bool = True
use_hinge_gan_loss: bool = False
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and univnet)
use_l1_spec_loss: bool = False
# loss weights - overrides
stft_loss_weight: float = 2.5
stft_loss_params: Dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240],
}
)
subband_stft_loss_weight: float = 0
mse_G_loss_weight: float = 1
hinge_G_loss_weight: float = 0
feat_match_loss_weight: float = 0
l1_spec_loss_weight: float = 0
l1_spec_loss_params: Dict = field(
default_factory=lambda: {
"use_mel": True,
"sample_rate": 22050,
"n_fft": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mels": 80,
"mel_fmin": 0.0,
"mel_fmax": None,
}
)
# optimizer parameters
lr_gen: float = 1e-4 # Initial learning rate.
lr_disc: float = 1e-4 # Initial learning rate.
lr_scheduler_gen: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
# lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
# lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
steps_to_start_discriminator: int = 200000
def __post_init__(self):
super().__post_init__()
self.generator_model_params["cond_channels"] = self.audio.num_mels
+90
View File
@@ -0,0 +1,90 @@
from dataclasses import dataclass, field
from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
from TTS.vocoder.models.wavegrad import WavegradArgs
@dataclass
class WavegradConfig(BaseVocoderConfig):
"""Defines parameters for WaveGrad vocoder.
Example:
>>> from TTS.vocoder.configs import WavegradConfig
>>> config = WavegradConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `wavegrad`.
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `wavegrad`.
model_params (WavegradArgs): Model parameters. Check `WavegradArgs` for default values.
target_loss (str):
Target loss name that defines the quality of the model. Defaults to `avg_wavegrad_loss`.
epochs (int):
Number of epochs to traing the model. Defaults to 10000.
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 96.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 6144.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
mixed_precision (bool):
enable / disable mixed precision training. Default is True.
eval_split_size (int):
Number of samples used for evalutaion. Defaults to 50.
train_noise_schedule (dict):
Training noise schedule. Defaults to
`{"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}`
test_noise_schedule (dict):
Inference noise schedule. For a better performance, you may need to use `bin/tune_wavegrad.py` to find a
better schedule. Defaults to
`
{
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 50,
}
`
grad_clip (float):
Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 1.0
lr (float):
Initila leraning rate. Defaults to 1e-4.
lr_scheduler (str):
One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`.
lr_scheduler_params (dict):
kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`
"""
model: str = "wavegrad"
# Model specific params
generator_model: str = "wavegrad"
model_params: WavegradArgs = field(default_factory=WavegradArgs)
target_loss: str = "loss" # loss value to pick the best model to save after each epoch
# Training - overrides
epochs: int = 10000
batch_size: int = 96
seq_len: int = 6144
use_cache: bool = True
mixed_precision: bool = True
eval_split_size: int = 50
# NOISE SCHEDULE PARAMS
train_noise_schedule: dict = field(default_factory=lambda: {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000})
test_noise_schedule: dict = field(
default_factory=lambda: { # inference noise schedule. Try TTS/bin/tune_wavegrad.py to find the optimal values.
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 50,
}
)
# optimizer overrides
grad_clip: float = 1.0
lr: float = 1e-4 # Initial learning rate.
lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_params: dict = field(
default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
)
+102
View File
@@ -0,0 +1,102 @@
from dataclasses import dataclass, field
from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
from TTS.vocoder.models.wavernn import WavernnArgs
@dataclass
class WavernnConfig(BaseVocoderConfig):
"""Defines parameters for Wavernn vocoder.
Example:
>>> from TTS.vocoder.configs import WavernnConfig
>>> config = WavernnConfig()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `wavernn`.
mode (str):
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
Gaussian Distribution and `bits` for quantized bits as the model's output.
mulaw (bool):
enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
to `True`.
generator_model (str):
One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
considered as a generator too. Defaults to `WaveRNN`.
wavernn_model_params (dict):
kwargs for the WaveRNN model. Defaults to
`{
"rnn_dims": 512,
"fc_dims": 512,
"compute_dims": 128,
"res_out_dims": 128,
"num_res_blocks": 10,
"use_aux_net": True,
"use_upsample_net": True,
"upsample_factors": [4, 8, 8]
}`
batched (bool):
enable / disable the batched inference. It speeds up the inference by splitting the input into segments and
processing the segments in a batch. Then it merges the outputs with a certain overlap and smoothing. If
you set it False, without CUDA, it is too slow to be practical. Defaults to True.
target_samples (int):
Size of the segments in batched mode. Defaults to 11000.
overlap_sampels (int):
Size of the overlap between consecutive segments. Defaults to 550.
batch_size (int):
Batch size used at training. Larger values use more memory. Defaults to 256.
seq_len (int):
Audio segment length used at training. Larger values use more memory. Defaults to 1280.
use_noise_augment (bool):
enable / disable random noise added to the input waveform. The noise is added after computing the
features. Defaults to True.
use_cache (bool):
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
not large enough. Defaults to True.
mixed_precision (bool):
enable / disable mixed precision training. Default is True.
eval_split_size (int):
Number of samples used for evalutaion. Defaults to 50.
num_epochs_before_test (int):
Number of epochs waited to run the next evalution. Since inference takes some time, it is better to
wait some number of epochs not ot waste training time. Defaults to 10.
grad_clip (float):
Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 4.0
lr (float):
Initila leraning rate. Defaults to 1e-4.
lr_scheduler (str):
One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`.
lr_scheduler_params (dict):
kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [200000, 400000, 600000]}`
"""
model: str = "wavernn"
# Model specific params
model_args: WavernnArgs = field(default_factory=WavernnArgs)
target_loss: str = "loss"
# Inference
batched: bool = True
target_samples: int = 11000
overlap_samples: int = 550
# Training - overrides
epochs: int = 10000
batch_size: int = 256
seq_len: int = 1280
use_noise_augment: bool = False
use_cache: bool = True
mixed_precision: bool = True
eval_split_size: int = 50
num_epochs_before_test: int = (
10 # number of epochs to wait until the next test run (synthesizing a full audio clip).
)
# optimizer overrides
grad_clip: float = 4.0
lr: float = 1e-4 # Initial learning rate.
lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.5, "milestones": [200000, 400000, 600000]})
+58
View File
@@ -0,0 +1,58 @@
from typing import List
from coqpit import Coqpit
from torch.utils.data import Dataset
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
if config.model.lower() in "gan":
dataset = GANDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
is_training=not is_eval,
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
elif config.model.lower() == "wavegrad":
dataset = WaveGradDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
is_training=not is_eval,
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
elif config.model.lower() == "wavernn":
dataset = WaveRNNDataset(
ap=ap,
items=data_items,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad=config.model_params.pad,
mode=config.model_params.mode,
mulaw=config.model_params.mulaw,
is_training=not is_eval,
verbose=verbose,
)
else:
raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
return dataset
+152
View File
@@ -0,0 +1,152 @@
import glob
import os
import random
from multiprocessing import Manager
import numpy as np
import torch
from torch.utils.data import Dataset
class GANDataset(Dataset):
"""
GAN Dataset searchs for all the wav files under root path
and converts them to acoustic features on the fly and returns
random segments of (audio, feature) couples.
"""
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
return_pairs=False,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False,
):
super().__init__()
self.ap = ap
self.item_list = items
self.compute_feat = not isinstance(items[0], (tuple, list))
self.seq_len = seq_len
self.hop_len = hop_len
self.pad_short = pad_short
self.conv_pad = conv_pad
self.return_pairs = return_pairs
self.is_training = is_training
self.return_segments = return_segments
self.use_cache = use_cache
self.use_noise_augment = use_noise_augment
self.verbose = verbose
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
# map G and D instances
self.G_to_D_mappings = list(range(len(self.item_list)))
self.shuffle_mapping()
# cache acoustic features
if use_cache:
self.create_feature_cache()
def create_feature_cache(self):
self.manager = Manager()
self.cache = self.manager.list()
self.cache += [None for _ in range(len(self.item_list))]
@staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)
def __len__(self):
return len(self.item_list)
def __getitem__(self, idx):
"""Return different items for Generator and Discriminator and
cache acoustic features"""
# set the seed differently for each worker
if torch.utils.data.get_worker_info():
random.seed(torch.utils.data.get_worker_info().seed)
if self.return_segments:
item1 = self.load_item(idx)
if self.return_pairs:
idx2 = self.G_to_D_mappings[idx]
item2 = self.load_item(idx2)
return item1, item2
return item1
item1 = self.load_item(idx)
return item1
def _pad_short_samples(self, audio, mel=None):
"""Pad samples shorter than the output sequence length"""
if len(audio) < self.seq_len:
audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0)
if mel is not None and mel.shape[1] < self.feat_frame_len:
pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]
mel = np.pad(
mel,
([0, 0], [0, self.feat_frame_len - mel.shape[1]]),
mode="constant",
constant_values=pad_value.mean(),
)
return audio, mel
def shuffle_mapping(self):
random.shuffle(self.G_to_D_mappings)
def load_item(self, idx):
"""load (audio, feat) couple"""
if self.compute_feat:
# compute features from wav
wavpath = self.item_list[idx]
# print(wavpath)
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
mel = self.ap.melspectrogram(audio)
audio, mel = self._pad_short_samples(audio, mel)
else:
# load precomputed features
wavpath, feat_path = self.item_list[idx]
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
mel = np.load(feat_path)
audio, mel = self._pad_short_samples(audio, mel)
# correct the audio length wrt padding applied in stft
audio = np.pad(audio, (0, self.hop_len), mode="edge")
audio = audio[: mel.shape[-1] * self.hop_len]
assert (
mel.shape[-1] * self.hop_len == audio.shape[-1]
), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"
audio = torch.from_numpy(audio).float().unsqueeze(0)
mel = torch.from_numpy(mel).float().squeeze(0)
if self.return_segments:
max_mel_start = mel.shape[1] - self.feat_frame_len
mel_start = random.randint(0, max_mel_start)
mel_end = mel_start + self.feat_frame_len
mel = mel[:, mel_start:mel_end]
audio_start = mel_start * self.hop_len
audio = audio[:, audio_start : audio_start + self.seq_len]
if self.use_noise_augment and self.is_training and self.return_segments:
audio = audio + (1 / 32768) * torch.randn_like(audio)
return (mel, audio)
+75
View File
@@ -0,0 +1,75 @@
import glob
import os
from pathlib import Path
import numpy as np
from coqpit import Coqpit
from tqdm import tqdm
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
"""Process wav and compute mel and quantized wave signal.
It is mainly used by WaveRNN dataloader.
Args:
out_path (str): Parent folder path to save the files.
config (Coqpit): Model config.
ap (AudioProcessor): Audio processor.
"""
os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
wav_files = find_wav_files(config.data_path)
for path in tqdm(wav_files):
wav_name = Path(path).stem
quant_path = os.path.join(out_path, "quant", wav_name + ".npy")
mel_path = os.path.join(out_path, "mel", wav_name + ".npy")
y = ap.load_wav(path)
mel = ap.melspectrogram(y)
np.save(mel_path, mel)
if isinstance(config.mode, int):
quant = (
mulaw_encode(wav=y, mulaw_qc=config.mode)
if config.model_args.mulaw
else quantize(x=y, quantize_bits=config.mode)
)
np.save(quant_path, quant)
def find_wav_files(data_path, file_ext="wav"):
wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True)
return wav_paths
def find_feat_files(data_path):
feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True)
return feat_paths
def load_wav_data(data_path, eval_split_size, file_ext="wav"):
wav_paths = find_wav_files(data_path, file_ext=file_ext)
assert len(wav_paths) > 0, f" [!] {data_path} is empty."
np.random.seed(0)
np.random.shuffle(wav_paths)
return wav_paths[:eval_split_size], wav_paths[eval_split_size:]
def load_wav_feat_data(data_path, feat_path, eval_split_size):
wav_paths = find_wav_files(data_path)
feat_paths = find_feat_files(feat_path)
wav_paths.sort(key=lambda x: Path(x).stem)
feat_paths.sort(key=lambda x: Path(x).stem)
assert len(wav_paths) == len(feat_paths), f" [!] {len(wav_paths)} vs {feat_paths}"
for wav, feat in zip(wav_paths, feat_paths):
wav_name = Path(wav).stem
feat_name = Path(feat).stem
assert wav_name == feat_name
items = list(zip(wav_paths, feat_paths))
np.random.seed(0)
np.random.shuffle(items)
return items[:eval_split_size], items[eval_split_size:]
+151
View File
@@ -0,0 +1,151 @@
import glob
import os
import random
from multiprocessing import Manager
from typing import List, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
class WaveGradDataset(Dataset):
"""
WaveGrad Dataset searchs for all the wav files under root path
and converts them to acoustic features on the fly and returns
random segments of (audio, feature) couples.
"""
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False,
):
super().__init__()
self.ap = ap
self.item_list = items
self.seq_len = seq_len if return_segments else None
self.hop_len = hop_len
self.pad_short = pad_short
self.conv_pad = conv_pad
self.is_training = is_training
self.return_segments = return_segments
self.use_cache = use_cache
self.use_noise_augment = use_noise_augment
self.verbose = verbose
if return_segments:
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
# cache acoustic features
if use_cache:
self.create_feature_cache()
def create_feature_cache(self):
self.manager = Manager()
self.cache = self.manager.list()
self.cache += [None for _ in range(len(self.item_list))]
@staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)
def __len__(self):
return len(self.item_list)
def __getitem__(self, idx):
item = self.load_item(idx)
return item
def load_test_samples(self, num_samples: int) -> List[Tuple]:
"""Return test samples.
Args:
num_samples (int): Number of samples to return.
Returns:
List[Tuple]: melspectorgram and audio.
Shapes:
- melspectrogram (Tensor): :math:`[C, T]`
- audio (Tensor): :math:`[T_audio]`
"""
samples = []
return_segments = self.return_segments
self.return_segments = False
for idx in range(num_samples):
mel, audio = self.load_item(idx)
samples.append([mel, audio])
self.return_segments = return_segments
return samples
def load_item(self, idx):
"""load (audio, feat) couple"""
# compute features from wav
wavpath = self.item_list[idx]
if self.use_cache and self.cache[idx] is not None:
audio = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
if self.return_segments:
# correct audio length wrt segment length
if audio.shape[-1] < self.seq_len + self.pad_short:
audio = np.pad(
audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0
)
assert (
audio.shape[-1] >= self.seq_len + self.pad_short
), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
# correct the audio length wrt hop length
p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0)
if self.use_cache:
self.cache[idx] = audio
if self.return_segments:
max_start = len(audio) - self.seq_len
start = random.randint(0, max_start)
end = start + self.seq_len
audio = audio[start:end]
if self.use_noise_augment and self.is_training and self.return_segments:
audio = audio + (1 / 32768) * torch.randn_like(audio)
mel = self.ap.melspectrogram(audio)
mel = mel[..., :-1] # ignore the padding
audio = torch.from_numpy(audio).float()
mel = torch.from_numpy(mel).float().squeeze(0)
return (mel, audio)
@staticmethod
def collate_full_clips(batch):
"""This is used in tune_wavegrad.py.
It pads sequences to the max length."""
max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]
max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0]
mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
audios = torch.zeros([len(batch), max_audio_length])
for idx, b in enumerate(batch):
mel = b[0]
audio = b[1]
mels[idx, :, : mel.shape[1]] = mel
audios[idx, : audio.shape[0]] = audio
return mels, audios
+118
View File
@@ -0,0 +1,118 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
class WaveRNNDataset(Dataset):
"""
WaveRNN Dataset searchs for all the wav files under root path
and converts them to acoustic features on the fly.
"""
def __init__(
self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
):
super().__init__()
self.ap = ap
self.compute_feat = not isinstance(items[0], (tuple, list))
self.item_list = items
self.seq_len = seq_len
self.hop_len = hop_len
self.mel_len = seq_len // hop_len
self.pad = pad
self.mode = mode
self.mulaw = mulaw
self.is_training = is_training
self.verbose = verbose
self.return_segments = return_segments
assert self.seq_len % self.hop_len == 0
def __len__(self):
return len(self.item_list)
def __getitem__(self, index):
item = self.load_item(index)
return item
def load_test_samples(self, num_samples):
samples = []
return_segments = self.return_segments
self.return_segments = False
for idx in range(num_samples):
mel, audio, _ = self.load_item(idx)
samples.append([mel, audio])
self.return_segments = return_segments
return samples
def load_item(self, index):
"""
load (audio, feat) couple if feature_path is set
else compute it on the fly
"""
if self.compute_feat:
wavpath = self.item_list[index]
audio = self.ap.load_wav(wavpath)
if self.return_segments:
min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
else:
min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
if audio.shape[0] < min_audio_len:
print(" [!] Instance is too short! : {}".format(wavpath))
audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
mel = self.ap.melspectrogram(audio)
if self.mode in ["gauss", "mold"]:
x_input = audio
elif isinstance(self.mode, int):
x_input = (
mulaw_encode(wav=audio, mulaw_qc=self.mode)
if self.mulaw
else quantize(x=audio, quantize_bits=self.mode)
)
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
else:
wavpath, feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/"))
if mel.shape[-1] < self.mel_len + 2 * self.pad:
print(" [!] Instance is too short! : {}".format(wavpath))
self.item_list[index] = self.item_list[index + 1]
feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/"))
if self.mode in ["gauss", "mold"]:
x_input = self.ap.load_wav(wavpath)
elif isinstance(self.mode, int):
x_input = np.load(feat_path.replace("/mel/", "/quant/"))
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
return mel, x_input, wavpath
def collate(self, batch):
mel_win = self.seq_len // self.hop_len + 2 * self.pad
max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]
mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)]
mels = np.stack(mels).astype(np.float32)
if self.mode in ["gauss", "mold"]:
coarse = np.stack(coarse).astype(np.float32)
coarse = torch.FloatTensor(coarse)
x_input = coarse[:, : self.seq_len]
elif isinstance(self.mode, int):
coarse = np.stack(coarse).astype(np.int64)
coarse = torch.LongTensor(coarse)
x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0
y_coarse = coarse[:, 1:]
mels = torch.FloatTensor(mels)
return x_input, mels, y_coarse
View File
+56
View File
@@ -0,0 +1,56 @@
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
# pylint: disable=dangerous-default-value
class ResStack(nn.Module):
def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]):
super().__init__()
resstack = []
for dilation in dilations:
resstack += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding),
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
]
self.resstack = nn.Sequential(*resstack)
self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
def forward(self, x):
x1 = self.shortcut(x)
x2 = self.resstack(x)
return x1 + x2
def remove_weight_norm(self):
remove_parametrizations(self.shortcut, "weight")
remove_parametrizations(self.resstack[2], "weight")
remove_parametrizations(self.resstack[5], "weight")
remove_parametrizations(self.resstack[8], "weight")
remove_parametrizations(self.resstack[11], "weight")
remove_parametrizations(self.resstack[14], "weight")
remove_parametrizations(self.resstack[17], "weight")
class MRF(nn.Module):
def __init__(self, kernels, channel, dilations=[1, 3, 5]): # # pylint: disable=dangerous-default-value
super().__init__()
self.resblock1 = ResStack(kernels[0], channel, 0, dilations)
self.resblock2 = ResStack(kernels[1], channel, 6, dilations)
self.resblock3 = ResStack(kernels[2], channel, 12, dilations)
def forward(self, x):
x1 = self.resblock1(x)
x2 = self.resblock2(x)
x3 = self.resblock3(x)
return x1 + x2 + x3
def remove_weight_norm(self):
self.resblock1.remove_weight_norm()
self.resblock2.remove_weight_norm()
self.resblock3.remove_weight_norm()
+368
View File
@@ -0,0 +1,368 @@
from typing import Dict, Union
import torch
from torch import nn
from torch.nn import functional as F
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
#################################
# GENERATOR LOSSES
#################################
class STFTLoss(nn.Module):
"""STFT loss. Input generate and real waveforms are converted
to spectrograms compared with L1 and Spectral convergence losses.
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
def __init__(self, n_fft, hop_length, win_length):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.stft = TorchSTFT(n_fft, hop_length, win_length)
def forward(self, y_hat, y):
y_hat_M = self.stft(y_hat)
y_M = self.stft(y)
# magnitude loss
loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
# spectral convergence loss
loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
return loss_mag, loss_sc
class MultiScaleSTFTLoss(torch.nn.Module):
"""Multi-scale STFT loss. Input generate and real waveforms are converted
to spectrograms compared with L1 and Spectral convergence losses.
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
super().__init__()
self.loss_funcs = torch.nn.ModuleList()
for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
def forward(self, y_hat, y):
N = len(self.loss_funcs)
loss_sc = 0
loss_mag = 0
for f in self.loss_funcs:
lm, lsc = f(y_hat, y)
loss_mag += lm
loss_sc += lsc
loss_sc /= N
loss_mag /= N
return loss_mag, loss_sc
class L1SpecLoss(nn.Module):
"""L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""
def __init__(
self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
):
super().__init__()
self.use_mel = use_mel
self.stft = TorchSTFT(
n_fft,
hop_length,
win_length,
sample_rate=sample_rate,
mel_fmin=mel_fmin,
mel_fmax=mel_fmax,
n_mels=n_mels,
use_mel=use_mel,
)
def forward(self, y_hat, y):
y_hat_M = self.stft(y_hat)
y_M = self.stft(y)
# magnitude loss
loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
return loss_mag
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
"""Multiscale STFT loss for multi band model outputs.
From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""
# pylint: disable=no-self-use
def forward(self, y_hat, y):
y_hat = y_hat.view(-1, 1, y_hat.shape[2])
y = y.view(-1, 1, y.shape[2])
return super().forward(y_hat.squeeze(1), y.squeeze(1))
class MSEGLoss(nn.Module):
"""Mean Squared Generator Loss"""
# pylint: disable=no-self-use
def forward(self, score_real):
loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
return loss_fake
class HingeGLoss(nn.Module):
"""Hinge Discriminator Loss"""
# pylint: disable=no-self-use
def forward(self, score_real):
# TODO: this might be wrong
loss_fake = torch.mean(F.relu(1.0 - score_real))
return loss_fake
##################################
# DISCRIMINATOR LOSSES
##################################
class MSEDLoss(nn.Module):
"""Mean Squared Discriminator Loss"""
def __init__(
self,
):
super().__init__()
self.loss_func = nn.MSELoss()
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
loss_d = loss_real + loss_fake
return loss_d, loss_real, loss_fake
class HingeDLoss(nn.Module):
"""Hinge Discriminator Loss"""
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = torch.mean(F.relu(1.0 - score_real))
loss_fake = torch.mean(F.relu(1.0 + score_fake))
loss_d = loss_real + loss_fake
return loss_d, loss_real, loss_fake
class MelganFeatureLoss(nn.Module):
def __init__(
self,
):
super().__init__()
self.loss_func = nn.L1Loss()
# pylint: disable=no-self-use
def forward(self, fake_feats, real_feats):
loss_feats = 0
num_feats = 0
for idx, _ in enumerate(fake_feats):
for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
loss_feats += self.loss_func(fake_feat, real_feat)
num_feats += 1
loss_feats = loss_feats / num_feats
return loss_feats
#####################################
# LOSS WRAPPERS
#####################################
def _apply_G_adv_loss(scores_fake, loss_func):
"""Compute G adversarial loss function
and normalize values"""
adv_loss = 0
if isinstance(scores_fake, list):
for score_fake in scores_fake:
fake_loss = loss_func(score_fake)
adv_loss += fake_loss
adv_loss /= len(scores_fake)
else:
fake_loss = loss_func(scores_fake)
adv_loss = fake_loss
return adv_loss
def _apply_D_loss(scores_fake, scores_real, loss_func):
"""Compute D loss func and normalize loss values"""
loss = 0
real_loss = 0
fake_loss = 0
if isinstance(scores_fake, list):
# multi-scale loss
for score_fake, score_real in zip(scores_fake, scores_real):
total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real)
loss += total_loss
real_loss += real_loss_
fake_loss += fake_loss_
# normalize loss values with number of scales (discriminators)
loss /= len(scores_fake)
real_loss /= len(scores_real)
fake_loss /= len(scores_fake)
else:
# single scale loss
total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
loss = total_loss
return loss, real_loss, fake_loss
##################################
# MODEL LOSSES
##################################
class GeneratorLoss(nn.Module):
"""Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes
losses. It allows to experiment with different combinations of loss functions with different models by just
changing configurations.
Args:
C (AttrDict): model configuration.
"""
def __init__(self, C):
super().__init__()
assert not (
C.use_mse_gan_loss and C.use_hinge_gan_loss
), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False
self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False
self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False
self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False
self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0
self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0
self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0
self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0
self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0
self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0
if C.use_stft_loss:
self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
if C.use_subband_stft_loss:
self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
if C.use_mse_gan_loss:
self.mse_loss = MSEGLoss()
if C.use_hinge_gan_loss:
self.hinge_loss = HingeGLoss()
if C.use_feat_match_loss:
self.feat_match_loss = MelganFeatureLoss()
if C.use_l1_spec_loss:
assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)
def forward(
self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
):
gen_loss = 0
adv_loss = 0
return_dict = {}
# STFT Loss
if self.use_stft_loss:
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
return_dict["G_stft_loss_mg"] = stft_loss_mg
return_dict["G_stft_loss_sc"] = stft_loss_sc
gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
# L1 Spec loss
if self.use_l1_spec_loss:
l1_spec_loss = self.l1_spec_loss(y_hat, y)
return_dict["G_l1_spec_loss"] = l1_spec_loss
gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss
# subband STFT Loss
if self.use_subband_stft_loss:
subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
# multiscale MSE adversarial loss
if self.use_mse_gan_loss and scores_fake is not None:
mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
return_dict["G_mse_fake_loss"] = mse_fake_loss
adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss
# multiscale Hinge adversarial loss
if self.use_hinge_gan_loss and not scores_fake is not None:
hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
return_dict["G_hinge_fake_loss"] = hinge_fake_loss
adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss
# Feature Matching Loss
if self.use_feat_match_loss and not feats_fake is None:
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
return_dict["G_feat_match_loss"] = feat_match_loss
adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
return_dict["loss"] = gen_loss + adv_loss
return_dict["G_gen_loss"] = gen_loss
return_dict["G_adv_loss"] = adv_loss
return return_dict
class DiscriminatorLoss(nn.Module):
"""Like ```GeneratorLoss```"""
def __init__(self, C):
super().__init__()
assert not (
C.use_mse_gan_loss and C.use_hinge_gan_loss
), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_mse_gan_loss = C.use_mse_gan_loss
self.use_hinge_gan_loss = C.use_hinge_gan_loss
if C.use_mse_gan_loss:
self.mse_loss = MSEDLoss()
if C.use_hinge_gan_loss:
self.hinge_loss = HingeDLoss()
def forward(self, scores_fake, scores_real):
loss = 0
return_dict = {}
if self.use_mse_gan_loss:
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
)
return_dict["D_mse_gan_loss"] = mse_D_loss
return_dict["D_mse_gan_real_loss"] = mse_D_real_loss
return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss
loss += mse_D_loss
if self.use_hinge_gan_loss:
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(
scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
)
return_dict["D_hinge_gan_loss"] = hinge_D_loss
return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss
return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
loss += hinge_D_loss
return_dict["loss"] = loss
return return_dict
class WaveRNNLoss(nn.Module):
def __init__(self, wave_rnn_mode: Union[str, int]):
super().__init__()
if wave_rnn_mode == "mold":
self.loss_func = discretized_mix_logistic_loss
elif wave_rnn_mode == "gauss":
self.loss_func = gaussian_loss
elif isinstance(wave_rnn_mode, int):
self.loss_func = torch.nn.CrossEntropyLoss()
else:
raise ValueError(" [!] Unknown mode for Wavernn.")
def forward(self, y_hat, y) -> Dict:
loss = self.loss_func(y_hat, y)
return {"loss": loss}
+198
View File
@@ -0,0 +1,198 @@
import torch
import torch.nn.functional as F
class KernelPredictor(torch.nn.Module):
"""Kernel predictor for the location-variable convolutions"""
def __init__( # pylint: disable=dangerous-default-value
self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
):
"""
Args:
cond_channels (int): number of channel for the conditioning sequence,
conv_in_channels (int): number of channel for the input sequence,
conv_out_channels (int): number of channel for the output sequence,
conv_layers (int):
kpnet_
"""
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
l_b = conv_out_channels * conv_layers
padding = (kpnet_conv_size - 1) // 2
self.input_conv = torch.nn.Sequential(
torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_conv = torch.nn.Sequential(
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True)
self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True)
def forward(self, c):
"""
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
"""
batch, _, cond_length = c.shape
c = self.input_conv(c)
c = c + self.residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(
batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length
)
bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length)
return kernels, bias
class LVCBlock(torch.nn.Module):
"""the location-variable convolutions"""
def __init__(
self,
in_channels,
cond_channels,
upsample_ratio,
conv_layers=4,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = conv_layers
self.conv_kernel_size = conv_kernel_size
self.convs = torch.nn.ModuleList()
self.upsample = torch.nn.ConvTranspose1d(
in_channels,
in_channels,
kernel_size=upsample_ratio * 2,
stride=upsample_ratio,
padding=upsample_ratio // 2 + upsample_ratio % 2,
output_padding=upsample_ratio % 2,
)
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=conv_layers,
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout,
)
for i in range(conv_layers):
padding = (3**i) * int((conv_kernel_size - 1) / 2)
conv = torch.nn.Conv1d(
in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i
)
self.convs.append(conv)
def forward(self, x, c):
"""forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
"""
in_channels = x.shape[1]
kernels, bias = self.kernel_predictor(c)
x = F.leaky_relu(x, 0.2)
x = self.upsample(x)
for i in range(self.conv_layers):
y = F.leaky_relu(x, 0.2)
y = self.convs[i](y)
y = F.leaky_relu(y, 0.2)
k = kernels[:, i, :, :, :, :]
b = bias[:, i, :, :]
y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
return x
@staticmethod
def location_variable_convolution(x, kernel, bias, dilation, hop_size):
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (
kernel_length * hop_size
), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
x = x.unfold(
3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o + bias.unsqueeze(-1).unsqueeze(-1)
o = o.contiguous().view(batch, out_channels, -1)
return o
+43
View File
@@ -0,0 +1,43 @@
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class ResidualStack(nn.Module):
def __init__(self, channels, num_res_blocks, kernel_size):
super().__init__()
assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
base_padding = (kernel_size - 1) // 2
self.blocks = nn.ModuleList()
for idx in range(num_res_blocks):
layer_kernel_size = kernel_size
layer_dilation = layer_kernel_size**idx
layer_padding = base_padding * layer_dilation
self.blocks += [
nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(layer_padding),
weight_norm(
nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
),
nn.LeakyReLU(0.2),
weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
)
]
self.shortcuts = nn.ModuleList(
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
)
def forward(self, x):
for block, shortcut in zip(self.blocks, self.shortcuts):
x = shortcut(x) + block(x)
return x
def remove_weight_norm(self):
for block, shortcut in zip(self.blocks, self.shortcuts):
remove_parametrizations(block[2], "weight")
remove_parametrizations(block[4], "weight")
remove_parametrizations(shortcut, "weight")
+77
View File
@@ -0,0 +1,77 @@
import torch
from torch.nn import functional as F
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
dropout=0.0,
dilation=1,
bias=True,
use_causal_conv=False,
):
super().__init__()
self.dropout = dropout
# no future time stamps available
if use_causal_conv:
padding = (kernel_size - 1) * dilation
else:
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
padding = (kernel_size - 1) // 2 * dilation
self.use_causal_conv = use_causal_conv
# dilation conv
self.conv = torch.nn.Conv1d(
res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias
)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
else:
self.conv1x1_aux = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)
def forward(self, x, c):
"""
x: B x D_res x T
c: B x D_aux x T
"""
residual = x
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv(x)
# remove future time steps if use_causal_conv conv
x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x
# split into two part for gated activation
splitdim = 1
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
# local conditioning
if c is not None:
assert self.conv1x1_aux is not None
c = self.conv1x1_aux(c)
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ca, xb + cb
x = torch.tanh(xa) * torch.sigmoid(xb)
# for skip connection
s = self.conv1x1_skip(x)
# for residual connection
x = (self.conv1x1_out(x) + residual) * (0.5**2)
return x, s
+53
View File
@@ -0,0 +1,53 @@
import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal as sig
# adapted from
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
class PQMF(torch.nn.Module):
def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
super().__init__()
self.N = N
self.taps = taps
self.cutoff = cutoff
self.beta = beta
QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta))
H = np.zeros((N, len(QMF)))
G = np.zeros((N, len(QMF)))
for k in range(N):
constant_factor = (
(2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2))
) # TODO: (taps - 1) -> taps
phase = (-1) ** k * np.pi / 4
H[k] = 2 * QMF * np.cos(constant_factor + phase)
G[k] = 2 * QMF * np.cos(constant_factor - phase)
H = torch.from_numpy(H[:, None, :]).float()
G = torch.from_numpy(G[None, :, :]).float()
self.register_buffer("H", H)
self.register_buffer("G", G)
updown_filter = torch.zeros((N, N, N)).float()
for k in range(N):
updown_filter[k, k, 0] = 1.0
self.register_buffer("updown_filter", updown_filter)
self.N = N
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def forward(self, x):
return self.analysis(x)
def analysis(self, x):
return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)
def synthesis(self, x):
x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N)
x = F.conv1d(x, self.G, padding=self.taps // 2)
return x
+640
View File
@@ -0,0 +1,640 @@
0.0000000e+000
-5.5252865e-004
-5.6176926e-004
-4.9475181e-004
-4.8752280e-004
-4.8937912e-004
-5.0407143e-004
-5.2265643e-004
-5.4665656e-004
-5.6778026e-004
-5.8709305e-004
-6.1327474e-004
-6.3124935e-004
-6.5403334e-004
-6.7776908e-004
-6.9416146e-004
-7.1577365e-004
-7.2550431e-004
-7.4409419e-004
-7.4905981e-004
-7.6813719e-004
-7.7248486e-004
-7.8343323e-004
-7.7798695e-004
-7.8036647e-004
-7.8014496e-004
-7.7579773e-004
-7.6307936e-004
-7.5300014e-004
-7.3193572e-004
-7.2153920e-004
-6.9179375e-004
-6.6504151e-004
-6.3415949e-004
-5.9461189e-004
-5.5645764e-004
-5.1455722e-004
-4.6063255e-004
-4.0951215e-004
-3.5011759e-004
-2.8969812e-004
-2.0983373e-004
-1.4463809e-004
-6.1733441e-005
1.3494974e-005
1.0943831e-004
2.0430171e-004
2.9495311e-004
4.0265402e-004
5.1073885e-004
6.2393761e-004
7.4580259e-004
8.6084433e-004
9.8859883e-004
1.1250155e-003
1.2577885e-003
1.3902495e-003
1.5443220e-003
1.6868083e-003
1.8348265e-003
1.9841141e-003
2.1461584e-003
2.3017255e-003
2.4625617e-003
2.6201759e-003
2.7870464e-003
2.9469448e-003
3.1125421e-003
3.2739613e-003
3.4418874e-003
3.6008268e-003
3.7603923e-003
3.9207432e-003
4.0819753e-003
4.2264269e-003
4.3730720e-003
4.5209853e-003
4.6606461e-003
4.7932561e-003
4.9137604e-003
5.0393023e-003
5.1407354e-003
5.2461166e-003
5.3471681e-003
5.4196776e-003
5.4876040e-003
5.5475715e-003
5.5938023e-003
5.6220643e-003
5.6455197e-003
5.6389200e-003
5.6266114e-003
5.5917129e-003
5.5404364e-003
5.4753783e-003
5.3838976e-003
5.2715759e-003
5.1382275e-003
4.9839688e-003
4.8109469e-003
4.6039530e-003
4.3801862e-003
4.1251642e-003
3.8456408e-003
3.5401247e-003
3.2091886e-003
2.8446758e-003
2.4508540e-003
2.0274176e-003
1.5784683e-003
1.0902329e-003
5.8322642e-004
2.7604519e-005
-5.4642809e-004
-1.1568136e-003
-1.8039473e-003
-2.4826724e-003
-3.1933778e-003
-3.9401124e-003
-4.7222596e-003
-5.5337211e-003
-6.3792293e-003
-7.2615817e-003
-8.1798233e-003
-9.1325330e-003
-1.0115022e-002
-1.1131555e-002
-1.2185000e-002
-1.3271822e-002
-1.4390467e-002
-1.5540555e-002
-1.6732471e-002
-1.7943338e-002
-1.9187243e-002
-2.0453179e-002
-2.1746755e-002
-2.3068017e-002
-2.4416099e-002
-2.5787585e-002
-2.7185943e-002
-2.8607217e-002
-3.0050266e-002
-3.1501761e-002
-3.2975408e-002
-3.4462095e-002
-3.5969756e-002
-3.7481285e-002
-3.9005368e-002
-4.0534917e-002
-4.2064909e-002
-4.3609754e-002
-4.5148841e-002
-4.6684303e-002
-4.8216572e-002
-4.9738576e-002
-5.1255616e-002
-5.2763075e-002
-5.4245277e-002
-5.5717365e-002
-5.7161645e-002
-5.8591568e-002
-5.9983748e-002
-6.1345517e-002
-6.2685781e-002
-6.3971590e-002
-6.5224711e-002
-6.6436751e-002
-6.7607599e-002
-6.8704383e-002
-6.9763024e-002
-7.0762871e-002
-7.1700267e-002
-7.2568258e-002
-7.3362026e-002
-7.4100364e-002
-7.4745256e-002
-7.5313734e-002
-7.5800836e-002
-7.6199248e-002
-7.6499217e-002
-7.6709349e-002
-7.6817398e-002
-7.6823001e-002
-7.6720492e-002
-7.6505072e-002
-7.6174832e-002
-7.5730576e-002
-7.5157626e-002
-7.4466439e-002
-7.3640601e-002
-7.2677464e-002
-7.1582636e-002
-7.0353307e-002
-6.8966401e-002
-6.7452502e-002
-6.5769067e-002
-6.3944481e-002
-6.1960278e-002
-5.9816657e-002
-5.7515269e-002
-5.5046003e-002
-5.2409382e-002
-4.9597868e-002
-4.6630331e-002
-4.3476878e-002
-4.0145828e-002
-3.6641812e-002
-3.2958393e-002
-2.9082401e-002
-2.5030756e-002
-2.0799707e-002
-1.6370126e-002
-1.1762383e-002
-6.9636862e-003
-1.9765601e-003
3.2086897e-003
8.5711749e-003
1.4128883e-002
1.9883413e-002
2.5822729e-002
3.1953127e-002
3.8277657e-002
4.4780682e-002
5.1480418e-002
5.8370533e-002
6.5440985e-002
7.2694330e-002
8.0137293e-002
8.7754754e-002
9.5553335e-002
1.0353295e-001
1.1168269e-001
1.2000780e-001
1.2850029e-001
1.3715518e-001
1.4597665e-001
1.5496071e-001
1.6409589e-001
1.7338082e-001
1.8281725e-001
1.9239667e-001
2.0212502e-001
2.1197359e-001
2.2196527e-001
2.3206909e-001
2.4230169e-001
2.5264803e-001
2.6310533e-001
2.7366340e-001
2.8432142e-001
2.9507167e-001
3.0590986e-001
3.1682789e-001
3.2781137e-001
3.3887227e-001
3.4999141e-001
3.6115899e-001
3.7237955e-001
3.8363500e-001
3.9492118e-001
4.0623177e-001
4.1756969e-001
4.2891199e-001
4.4025538e-001
4.5159965e-001
4.6293081e-001
4.7424532e-001
4.8552531e-001
4.9677083e-001
5.0798175e-001
5.1912350e-001
5.3022409e-001
5.4125534e-001
5.5220513e-001
5.6307891e-001
5.7385241e-001
5.8454032e-001
5.9511231e-001
6.0557835e-001
6.1591099e-001
6.2612427e-001
6.3619801e-001
6.4612697e-001
6.5590163e-001
6.6551399e-001
6.7496632e-001
6.8423533e-001
6.9332824e-001
7.0223887e-001
7.1094104e-001
7.1944626e-001
7.2774489e-001
7.3582118e-001
7.4368279e-001
7.5131375e-001
7.5870808e-001
7.6586749e-001
7.7277809e-001
7.7942875e-001
7.8583531e-001
7.9197358e-001
7.9784664e-001
8.0344858e-001
8.0876950e-001
8.1381913e-001
8.1857760e-001
8.2304199e-001
8.2722753e-001
8.3110385e-001
8.3469374e-001
8.3797173e-001
8.4095414e-001
8.4362383e-001
8.4598185e-001
8.4803158e-001
8.4978052e-001
8.5119715e-001
8.5230470e-001
8.5310209e-001
8.5357206e-001
8.5373856e-001
8.5357206e-001
8.5310209e-001
8.5230470e-001
8.5119715e-001
8.4978052e-001
8.4803158e-001
8.4598185e-001
8.4362383e-001
8.4095414e-001
8.3797173e-001
8.3469374e-001
8.3110385e-001
8.2722753e-001
8.2304199e-001
8.1857760e-001
8.1381913e-001
8.0876950e-001
8.0344858e-001
7.9784664e-001
7.9197358e-001
7.8583531e-001
7.7942875e-001
7.7277809e-001
7.6586749e-001
7.5870808e-001
7.5131375e-001
7.4368279e-001
7.3582118e-001
7.2774489e-001
7.1944626e-001
7.1094104e-001
7.0223887e-001
6.9332824e-001
6.8423533e-001
6.7496632e-001
6.6551399e-001
6.5590163e-001
6.4612697e-001
6.3619801e-001
6.2612427e-001
6.1591099e-001
6.0557835e-001
5.9511231e-001
5.8454032e-001
5.7385241e-001
5.6307891e-001
5.5220513e-001
5.4125534e-001
5.3022409e-001
5.1912350e-001
5.0798175e-001
4.9677083e-001
4.8552531e-001
4.7424532e-001
4.6293081e-001
4.5159965e-001
4.4025538e-001
4.2891199e-001
4.1756969e-001
4.0623177e-001
3.9492118e-001
3.8363500e-001
3.7237955e-001
3.6115899e-001
3.4999141e-001
3.3887227e-001
3.2781137e-001
3.1682789e-001
3.0590986e-001
2.9507167e-001
2.8432142e-001
2.7366340e-001
2.6310533e-001
2.5264803e-001
2.4230169e-001
2.3206909e-001
2.2196527e-001
2.1197359e-001
2.0212502e-001
1.9239667e-001
1.8281725e-001
1.7338082e-001
1.6409589e-001
1.5496071e-001
1.4597665e-001
1.3715518e-001
1.2850029e-001
1.2000780e-001
1.1168269e-001
1.0353295e-001
9.5553335e-002
8.7754754e-002
8.0137293e-002
7.2694330e-002
6.5440985e-002
5.8370533e-002
5.1480418e-002
4.4780682e-002
3.8277657e-002
3.1953127e-002
2.5822729e-002
1.9883413e-002
1.4128883e-002
8.5711749e-003
3.2086897e-003
-1.9765601e-003
-6.9636862e-003
-1.1762383e-002
-1.6370126e-002
-2.0799707e-002
-2.5030756e-002
-2.9082401e-002
-3.2958393e-002
-3.6641812e-002
-4.0145828e-002
-4.3476878e-002
-4.6630331e-002
-4.9597868e-002
-5.2409382e-002
-5.5046003e-002
-5.7515269e-002
-5.9816657e-002
-6.1960278e-002
-6.3944481e-002
-6.5769067e-002
-6.7452502e-002
-6.8966401e-002
-7.0353307e-002
-7.1582636e-002
-7.2677464e-002
-7.3640601e-002
-7.4466439e-002
-7.5157626e-002
-7.5730576e-002
-7.6174832e-002
-7.6505072e-002
-7.6720492e-002
-7.6823001e-002
-7.6817398e-002
-7.6709349e-002
-7.6499217e-002
-7.6199248e-002
-7.5800836e-002
-7.5313734e-002
-7.4745256e-002
-7.4100364e-002
-7.3362026e-002
-7.2568258e-002
-7.1700267e-002
-7.0762871e-002
-6.9763024e-002
-6.8704383e-002
-6.7607599e-002
-6.6436751e-002
-6.5224711e-002
-6.3971590e-002
-6.2685781e-002
-6.1345517e-002
-5.9983748e-002
-5.8591568e-002
-5.7161645e-002
-5.5717365e-002
-5.4245277e-002
-5.2763075e-002
-5.1255616e-002
-4.9738576e-002
-4.8216572e-002
-4.6684303e-002
-4.5148841e-002
-4.3609754e-002
-4.2064909e-002
-4.0534917e-002
-3.9005368e-002
-3.7481285e-002
-3.5969756e-002
-3.4462095e-002
-3.2975408e-002
-3.1501761e-002
-3.0050266e-002
-2.8607217e-002
-2.7185943e-002
-2.5787585e-002
-2.4416099e-002
-2.3068017e-002
-2.1746755e-002
-2.0453179e-002
-1.9187243e-002
-1.7943338e-002
-1.6732471e-002
-1.5540555e-002
-1.4390467e-002
-1.3271822e-002
-1.2185000e-002
-1.1131555e-002
-1.0115022e-002
-9.1325330e-003
-8.1798233e-003
-7.2615817e-003
-6.3792293e-003
-5.5337211e-003
-4.7222596e-003
-3.9401124e-003
-3.1933778e-003
-2.4826724e-003
-1.8039473e-003
-1.1568136e-003
-5.4642809e-004
2.7604519e-005
5.8322642e-004
1.0902329e-003
1.5784683e-003
2.0274176e-003
2.4508540e-003
2.8446758e-003
3.2091886e-003
3.5401247e-003
3.8456408e-003
4.1251642e-003
4.3801862e-003
4.6039530e-003
4.8109469e-003
4.9839688e-003
5.1382275e-003
5.2715759e-003
5.3838976e-003
5.4753783e-003
5.5404364e-003
5.5917129e-003
5.6266114e-003
5.6389200e-003
5.6455197e-003
5.6220643e-003
5.5938023e-003
5.5475715e-003
5.4876040e-003
5.4196776e-003
5.3471681e-003
5.2461166e-003
5.1407354e-003
5.0393023e-003
4.9137604e-003
4.7932561e-003
4.6606461e-003
4.5209853e-003
4.3730720e-003
4.2264269e-003
4.0819753e-003
3.9207432e-003
3.7603923e-003
3.6008268e-003
3.4418874e-003
3.2739613e-003
3.1125421e-003
2.9469448e-003
2.7870464e-003
2.6201759e-003
2.4625617e-003
2.3017255e-003
2.1461584e-003
1.9841141e-003
1.8348265e-003
1.6868083e-003
1.5443220e-003
1.3902495e-003
1.2577885e-003
1.1250155e-003
9.8859883e-004
8.6084433e-004
7.4580259e-004
6.2393761e-004
5.1073885e-004
4.0265402e-004
2.9495311e-004
2.0430171e-004
1.0943831e-004
1.3494974e-005
-6.1733441e-005
-1.4463809e-004
-2.0983373e-004
-2.8969812e-004
-3.5011759e-004
-4.0951215e-004
-4.6063255e-004
-5.1455722e-004
-5.5645764e-004
-5.9461189e-004
-6.3415949e-004
-6.6504151e-004
-6.9179375e-004
-7.2153920e-004
-7.3193572e-004
-7.5300014e-004
-7.6307936e-004
-7.7579773e-004
-7.8014496e-004
-7.8036647e-004
-7.7798695e-004
-7.8343323e-004
-7.7248486e-004
-7.6813719e-004
-7.4905981e-004
-7.4409419e-004
-7.2550431e-004
-7.1577365e-004
-6.9416146e-004
-6.7776908e-004
-6.5403334e-004
-6.3124935e-004
-6.1327474e-004
-5.8709305e-004
-5.6778026e-004
-5.4665656e-004
-5.2265643e-004
-5.0407143e-004
-4.8937912e-004
-4.8752280e-004
-4.9475181e-004
-5.6176926e-004
-5.5252865e-004
+102
View File
@@ -0,0 +1,102 @@
import torch
from torch.nn import functional as F
class Stretch2d(torch.nn.Module):
def __init__(self, x_scale, y_scale, mode="nearest"):
super().__init__()
self.x_scale = x_scale
self.y_scale = y_scale
self.mode = mode
def forward(self, x):
"""
x (Tensor): Input tensor (B, C, F, T).
Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
"""
return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)
class UpsampleNetwork(torch.nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
self,
upsample_factors,
nonlinear_activation=None,
nonlinear_activation_params={},
interpolate_mode="nearest",
freq_axis_kernel_size=1,
use_causal_conv=False,
):
super().__init__()
self.use_causal_conv = use_causal_conv
self.up_layers = torch.nn.ModuleList()
for scale in upsample_factors:
# interpolation layer
stretch = Stretch2d(scale, 1, interpolate_mode)
self.up_layers += [stretch]
# conv layer
assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
freq_axis_padding = (freq_axis_kernel_size - 1) // 2
kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
if use_causal_conv:
padding = (freq_axis_padding, scale * 2)
else:
padding = (freq_axis_padding, scale)
conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
self.up_layers += [conv]
# nonlinear
if nonlinear_activation is not None:
nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
self.up_layers += [nonlinear]
def forward(self, c):
"""
c : (B, C, T_in).
Tensor: (B, C, T_upsample)
"""
c = c.unsqueeze(1) # (B, 1, C, T)
for f in self.up_layers:
c = f(c)
return c.squeeze(1) # (B, C, T')
class ConvUpsample(torch.nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
self,
upsample_factors,
nonlinear_activation=None,
nonlinear_activation_params={},
interpolate_mode="nearest",
freq_axis_kernel_size=1,
aux_channels=80,
aux_context_window=0,
use_causal_conv=False,
):
super().__init__()
self.aux_context_window = aux_context_window
self.use_causal_conv = use_causal_conv and aux_context_window > 0
# To capture wide-context information in conditional features
kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
# NOTE(kan-bayashi): Here do not use padding because the input is already padded
self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
self.upsample = UpsampleNetwork(
upsample_factors=upsample_factors,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
interpolate_mode=interpolate_mode,
freq_axis_kernel_size=freq_axis_kernel_size,
use_causal_conv=use_causal_conv,
)
def forward(self, c):
"""
c : (B, C, T_in).
Tensor: (B, C, T_upsampled),
"""
c_ = self.conv_in(c)
c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
return self.upsample(c)
+166
View File
@@ -0,0 +1,166 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class Conv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
nn.init.orthogonal_(self.weight)
nn.init.zeros_(self.bias)
class PositionalEncoding(nn.Module):
"""Positional encoding with noise level conditioning"""
def __init__(self, n_channels, max_len=10000):
super().__init__()
self.n_channels = n_channels
self.max_len = max_len
self.C = 5000
self.pe = torch.zeros(0, 0)
def forward(self, x, noise_level):
if x.shape[2] > self.pe.shape[1]:
self.init_pe_matrix(x.shape[1], x.shape[2], x)
return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C
def init_pe_matrix(self, n_channels, max_len, x):
pe = torch.zeros(max_len, n_channels)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
pe[:, 0::2] = torch.sin(position / div_term)
pe[:, 1::2] = torch.cos(position / div_term)
self.pe = pe.transpose(0, 1).to(x)
class FiLM(nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.encoding = PositionalEncoding(input_size)
self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
nn.init.xavier_uniform_(self.input_conv.weight)
nn.init.xavier_uniform_(self.output_conv.weight)
nn.init.zeros_(self.input_conv.bias)
nn.init.zeros_(self.output_conv.bias)
def forward(self, x, noise_scale):
o = self.input_conv(x)
o = F.leaky_relu(o, 0.2)
o = self.encoding(o, noise_scale)
shift, scale = torch.chunk(self.output_conv(o), 2, dim=1)
return shift, scale
def remove_weight_norm(self):
remove_parametrizations(self.input_conv, "weight")
remove_parametrizations(self.output_conv, "weight")
def apply_weight_norm(self):
self.input_conv = weight_norm(self.input_conv)
self.output_conv = weight_norm(self.output_conv)
@torch.jit.script
def shif_and_scale(x, scale, shift):
o = shift + scale * x
return o
class UBlock(nn.Module):
def __init__(self, input_size, hidden_size, factor, dilation):
super().__init__()
assert isinstance(dilation, (list, tuple))
assert len(dilation) == 4
self.factor = factor
self.res_block = Conv1d(input_size, hidden_size, 1)
self.main_block = nn.ModuleList(
[
Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]),
]
)
self.out_block = nn.ModuleList(
[
Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]),
]
)
def forward(self, x, shift, scale):
x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
res = self.res_block(x_inter)
o = F.leaky_relu(x_inter, 0.2)
o = F.interpolate(o, size=x.shape[-1] * self.factor)
o = self.main_block[0](o)
o = shif_and_scale(o, scale, shift)
o = F.leaky_relu(o, 0.2)
o = self.main_block[1](o)
res2 = res + o
o = shif_and_scale(res2, scale, shift)
o = F.leaky_relu(o, 0.2)
o = self.out_block[0](o)
o = shif_and_scale(o, scale, shift)
o = F.leaky_relu(o, 0.2)
o = self.out_block[1](o)
o = o + res2
return o
def remove_weight_norm(self):
remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
remove_parametrizations(layer, "weight")
for _, layer in enumerate(self.out_block):
if len(layer.state_dict()) != 0:
remove_parametrizations(layer, "weight")
def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block)
for idx, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
self.main_block[idx] = weight_norm(layer)
for idx, layer in enumerate(self.out_block):
if len(layer.state_dict()) != 0:
self.out_block[idx] = weight_norm(layer)
class DBlock(nn.Module):
def __init__(self, input_size, hidden_size, factor):
super().__init__()
self.factor = factor
self.res_block = Conv1d(input_size, hidden_size, 1)
self.main_block = nn.ModuleList(
[
Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
]
)
def forward(self, x):
size = x.shape[-1] // self.factor
res = self.res_block(x)
res = F.interpolate(res, size=size)
o = F.interpolate(x, size=size)
for layer in self.main_block:
o = F.leaky_relu(o, 0.2)
o = layer(o)
return o + res
def remove_weight_norm(self):
remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
remove_parametrizations(layer, "weight")
def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block)
for idx, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
self.main_block[idx] = weight_norm(layer)
+154
View File
@@ -0,0 +1,154 @@
import importlib
import re
from coqpit import Coqpit
def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(config: Coqpit):
"""Load models directly from configuration."""
if "discriminator_model" in config and "generator_model" in config:
MyModel = importlib.import_module("TTS.vocoder.models.gan")
MyModel = getattr(MyModel, "GAN")
else:
MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
if config.model.lower() == "wavernn":
MyModel = getattr(MyModel, "Wavernn")
elif config.model.lower() == "gan":
MyModel = getattr(MyModel, "GAN")
elif config.model.lower() == "wavegrad":
MyModel = getattr(MyModel, "Wavegrad")
else:
try:
MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
print(" > Vocoder Model: {}".format(config.model))
return MyModel.init_from_config(config)
def setup_generator(c):
"""TODO: use config object as arguments"""
print(" > Generator Model: {}".format(c.generator_model))
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model))
# this is to preserve the Wavernn class name (instead of Wavernn)
if c.generator_model.lower() in "hifigan_generator":
model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
elif c.generator_model.lower() in "melgan_generator":
model = MyModel(
in_channels=c.audio["num_mels"],
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
elif c.generator_model in "melgan_fb_generator":
raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
elif c.generator_model.lower() in "multiband_melgan_generator":
model = MyModel(
in_channels=c.audio["num_mels"],
out_channels=4,
proj_kernel=7,
base_channels=384,
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
elif c.generator_model.lower() in "fullband_melgan_generator":
model = MyModel(
in_channels=c.audio["num_mels"],
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
elif c.generator_model.lower() in "parallel_wavegan_generator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
stacks=c.generator_model_params["stacks"],
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=c.audio["num_mels"],
dropout=0.0,
bias=True,
use_weight_norm=True,
upsample_factors=c.generator_model_params["upsample_factors"],
)
elif c.generator_model.lower() in "univnet_generator":
model = MyModel(**c.generator_model_params)
else:
raise NotImplementedError(f"Model {c.generator_model} not implemented!")
return model
def setup_discriminator(c):
"""TODO: use config objekt as arguments"""
print(" > Discriminator Model: {}".format(c.discriminator_model))
if "parallel_wavegan" in c.discriminator_model:
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
else:
MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
if c.discriminator_model in "hifigan_discriminator":
model = MyModel()
if c.discriminator_model in "random_window_discriminator":
model = MyModel(
cond_channels=c.audio["num_mels"],
hop_length=c.audio["hop_length"],
uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
window_sizes=c.discriminator_model_params["window_sizes"],
)
if c.discriminator_model in "melgan_multiscale_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=c.discriminator_model_params["base_channels"],
max_channels=c.discriminator_model_params["max_channels"],
downsample_factors=c.discriminator_model_params["downsample_factors"],
)
if c.discriminator_model == "residual_parallel_wavegan_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=c.discriminator_model_params["num_layers"],
stacks=c.discriminator_model_params["stacks"],
res_channels=64,
gate_channels=128,
skip_channels=64,
dropout=0.0,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
)
if c.discriminator_model == "parallel_wavegan_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=c.discriminator_model_params["num_layers"],
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True,
)
if c.discriminator_model == "univnet_discriminator":
model = MyModel()
return model
+55
View File
@@ -0,0 +1,55 @@
from coqpit import Coqpit
from TTS.model import BaseTrainerModel
# pylint: skip-file
class BaseVocoder(BaseTrainerModel):
"""Base `vocoder` class. Every new `vocoder` model must inherit this.
It defines `vocoder` specific functions on top of `Model`.
Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as
- 3D tensors `batch x time x channels`
- 2D tensors `batch x channels`
- 1D tensors `batch x 1`
"""
MODEL_TYPE = "vocoder"
def __init__(self, config):
super().__init__()
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type.
If the config is for training with a name like "*Config", then the model args are embeded in the
config.model_args
If the config is for the model with a name like "*Args", then we assign the directly.
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
if "characters" in config:
_, self.config, num_chars = self.get_characters(config)
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
config.model_args.num_chars = num_chars
if "model_args" in config:
self.args = self.config.model_args
# This is for backward compatibility
if "model_params" in config:
self.args = self.config.model_params
else:
self.config = config
if "model_args" in config:
self.args = self.config.model_args
# This is for backward compatibility
if "model_params" in config:
self.args = self.config.model_params
else:
raise ValueError("config must be either a *Config or *Args")
@@ -0,0 +1,33 @@
import torch
from TTS.vocoder.models.melgan_generator import MelganGenerator
class FullbandMelganGenerator(MelganGenerator):
def __init__(
self,
in_channels=80,
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=(2, 8, 2, 2),
res_kernel=3,
num_res_blocks=4,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
proj_kernel=proj_kernel,
base_channels=base_channels,
upsample_factors=upsample_factors,
res_kernel=res_kernel,
num_res_blocks=num_res_blocks,
)
@torch.no_grad()
def inference(self, cond_features):
cond_features = cond_features.to(self.layers[1].weight.device)
cond_features = torch.nn.functional.pad(
cond_features, (self.inference_padding, self.inference_padding), "replicate"
)
return self.layers(cond_features)
+374
View File
@@ -0,0 +1,374 @@
from inspect import signature
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.models import setup_discriminator, setup_generator
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results
class GAN(BaseVocoder):
def __init__(self, config: Coqpit, ap: AudioProcessor = None):
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
It also helps mixing and matching different generator and disciminator networks easily.
To implement a new GAN models, you just need to define the generator and the discriminator networks, the rest
is handled by the `GAN` class.
Args:
config (Coqpit): Model configuration.
ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.
Examples:
Initializing the GAN model with HifiGAN generator and discriminator.
>>> from TTS.vocoder.configs import HifiganConfig
>>> config = HifiganConfig()
>>> model = GAN(config)
"""
super().__init__(config)
self.config = config
self.model_g = setup_generator(config)
self.model_d = setup_discriminator(config)
self.train_disc = False # if False, train only the generator.
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
self.ap = ap
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's forward pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.forward(x)
def inference(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's inference pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.inference(x)
def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
network on the current pass.
Args:
batch (Dict): Batch of samples returned by the dataloader.
criterion (Dict): Criterion used to compute the losses.
optimizer_idx (int): ID of the optimizer in use on the current pass.
Raises:
ValueError: `optimizer_idx` is an unexpected value.
Returns:
Tuple[Dict, Dict]: model outputs and the computed loss values.
"""
outputs = {}
loss_dict = {}
x = batch["input"]
y = batch["waveform"]
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if optimizer_idx == 0:
# DISCRIMINATOR optimization
# generator pass
y_hat = self.model_g(x)[:, :, : y.size(2)]
# cache for generator loss
# pylint: disable=W0201
self.y_hat_g = y_hat
self.y_hat_sub = None
self.y_sub_g = None
# PQMF formatting
if y_hat.shape[1] > 1:
self.y_hat_sub = y_hat
y_hat = self.model_g.pqmf_synthesis(y_hat)
self.y_hat_g = y_hat # save for generator loss
self.y_sub_g = self.model_g.pqmf_analysis(y)
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
# use different samples for G and D trainings
if self.config.diff_samples_for_G_and_D:
x_d = batch["input_disc"]
y_d = batch["waveform_disc"]
# use a different sample than generator
with torch.no_grad():
y_hat = self.model_g(x_d)
# PQMF formatting
if y_hat.shape[1] > 1:
y_hat = self.model_g.pqmf_synthesis(y_hat)
else:
# use the same samples as generator
x_d = x.clone()
y_d = y.clone()
y_hat = self.y_hat_g
# run D with or without cond. features
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
D_out_real = self.model_d(y_d, x_d)
else:
D_out_fake = self.model_d(y_hat.detach())
D_out_real = self.model_d(y_d)
# format D outputs
if isinstance(D_out_fake, tuple):
# self.model_d returns scores and features
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
scores_real, feats_real = None, None
else:
scores_real, feats_real = D_out_real
else:
# model D returns only scores
scores_fake = D_out_fake
scores_real = D_out_real
# compute losses
loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
outputs = {"model_outputs": y_hat}
if optimizer_idx == 1:
# GENERATOR loss
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(self.y_hat_g, x)
else:
D_out_fake = self.model_d(self.y_hat_g)
D_out_real = None
if self.config.use_feat_match_loss:
with torch.no_grad():
D_out_real = self.model_d(y)
# format D outputs
if isinstance(D_out_fake, tuple):
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
feats_real = None
else:
_, feats_real = D_out_real
else:
scores_fake = D_out_fake
feats_fake, feats_real = None, None
# compute losses
loss_dict = criterion[optimizer_idx](
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
)
outputs = {"model_outputs": self.y_hat_g}
return outputs, loss_dict
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
"""Logging shared by the training and evaluation.
Args:
name (str): Name of the run. `train` or `eval`,
ap (AudioProcessor): Audio processor used in training.
batch (Dict): Batch used in the last train/eval step.
outputs (Dict): Model outputs from the last train/eval step.
Returns:
Tuple[Dict, Dict]: log figures and audio samples.
"""
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
y = batch["waveform"]
figures = plot_results(y_hat, y, ap, name)
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
audios = {f"{name}/audio": sample_voice}
return figures, audios
def train_log(
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for training."""
figures, audios = self._log("eval", self.ap, batch, outputs)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for evaluation."""
figures, audios = self._log("eval", self.ap, batch, outputs)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self,
config: Coqpit,
checkpoint_path: str,
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
cache: bool = False,
) -> None:
"""Load a GAN checkpoint and initialize model parameters.
Args:
config (Coqpit): Model config.
checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
# band-aid for older than v0.0.15 GAN models
if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval)
else:
self.load_state_dict(state["model"])
if eval:
self.model_d = None
if hasattr(self.model_g, "remove_weight_norm"):
self.model_g.remove_weight_norm()
def on_train_step_start(self, trainer) -> None:
"""Enable the discriminator training based on `steps_to_start_discriminator`
Args:
trainer (Trainer): Trainer object.
"""
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
Returns:
List: optimizers.
"""
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
)
optimizer2 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
)
return [optimizer2, optimizer1]
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_disc, self.config.lr_gen]
def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Args:
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler2, scheduler1]
@staticmethod
def format_batch(batch: List) -> Dict:
"""Format the batch for training.
Args:
batch (List): Batch out of the dataloader.
Returns:
Dict: formatted model inputs.
"""
if isinstance(batch[0], list):
x_G, y_G = batch[0]
x_D, y_D = batch[1]
return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
x, y = batch
return {"input": x, "waveform": y}
def get_data_loader( # pylint: disable=no-self-use, unused-argument
self,
config: Coqpit,
assets: Dict,
is_eval: True,
samples: List,
verbose: bool,
num_gpus: int,
rank: int = None, # pylint: disable=unused-argument
):
"""Initiate and return the GAN dataloader.
Args:
config (Coqpit): Model config.
ap (AudioProcessor): Audio processor.
is_eval (True): Set the dataloader for evaluation if true.
samples (List): Data samples.
verbose (bool): Log information if true.
num_gpus (int): Number of GPUs in use.
rank (int): Rank of the current GPU. Defaults to None.
Returns:
DataLoader: Torch dataloader.
"""
dataset = GANDataset(
ap=self.ap,
items=samples,
seq_len=config.seq_len,
hop_len=self.ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
is_training=not is_eval,
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=1 if is_eval else config.batch_size,
shuffle=num_gpus == 0,
drop_last=False,
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def get_criterion(self):
"""Return criterions for the optimizers"""
return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
@staticmethod
def init_from_config(config: Coqpit, verbose=True) -> "GAN":
ap = AudioProcessor.init_from_config(config, verbose=verbose)
return GAN(config, ap=ap)
+217
View File
@@ -0,0 +1,217 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import functional as F
LRELU_SLOPE = 0.1
class DiscriminatorP(torch.nn.Module):
"""HiFiGAN Periodic Discriminator
Takes every Pth value from the input waveform and applied a stack of convoluations.
Note:
if `period` is 2
`waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat`
Args:
x (Tensor): input waveform.
Returns:
[Tensor]: discriminator scores per sample in the batch.
[List[Tensor]]: list of features from each convolutional layer.
Shapes:
x: [B, 1, T]
"""
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super().__init__()
self.period = period
get_padding = lambda k, d: int((k * d - d) / 2)
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
[Tensor]: discriminator scores per sample in the batch.
[List[Tensor]]: list of features from each convolutional layer.
Shapes:
x: [B, 1, T]
"""
feat = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
class MultiPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN Multi-Period Discriminator (MPD)
Wrapper for the `PeriodDiscriminator` to apply it in different periods.
Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
]
)
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
[List[Tensor]]: list of scores from each discriminator.
[List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
Shapes:
x: [B, 1, T]
"""
scores = []
feats = []
for _, d in enumerate(self.discriminators):
score, feat = d(x)
scores.append(score)
feats.append(feat)
return scores, feats
class DiscriminatorS(torch.nn.Module):
"""HiFiGAN Scale Discriminator.
It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper.
Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
Tensor: discriminator scores.
List[Tensor]: list of features from the convolutiona layers.
"""
feat = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
class MultiScaleDiscriminator(torch.nn.Module):
"""HiFiGAN Multi-Scale Discriminator.
It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper.
"""
def __init__(self):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
]
)
self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)])
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores = []
feats = []
for i, d in enumerate(self.discriminators):
if i != 0:
x = self.meanpools[i - 1](x)
score, feat = d(x)
scores.append(score)
feats.append(feat)
return scores, feats
class HifiganDiscriminator(nn.Module):
"""HiFiGAN discriminator wrapping MPD and MSD."""
def __init__(self):
super().__init__()
self.mpd = MultiPeriodDiscriminator()
self.msd = MultiScaleDiscriminator()
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
scores_, feats_ = self.msd(x)
return scores + scores_, feats + feats_
+301
View File
@@ -0,0 +1,301 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1
def get_padding(k, d):
return int((k * d - d) / 2)
class ResBlock1(torch.nn.Module):
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|--------------------------------------------------------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
]
)
def forward(self, x):
"""
Args:
x (Tensor): input tensor.
Returns:
Tensor: output tensor.
Shapes:
x: [B, C, T]
"""
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module):
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|---------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super().__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_parametrizations(l, "weight")
class HifiganGenerator(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
resblock_type,
resblock_dilation_sizes,
resblock_kernel_sizes,
upsample_kernel_sizes,
upsample_initial_channel,
upsample_factors,
inference_padding=5,
cond_channels=0,
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
Network:
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
.. -> zI ---|
resblockN_kNx1 -> zN ---'
Args:
in_channels (int): number of input tensor channels.
out_channels (int): number of output tensor channels.
resblock_type (str): type of the `ResBlock`. '1' or '2'.
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
for each consecutive upsampling layer.
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
"""
super().__init__()
self.inference_padding = inference_padding
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_factors)
# initial upsampling layers
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# MRF blocks
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm:
remove_parametrizations(self.conv_pre, "weight")
if not conv_post_weight_norm:
remove_parametrizations(self.conv_post, "weight")
def forward(self, x, g=None):
"""
Args:
x (Tensor): feature input tensor.
g (Tensor): global conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
o = self.conv_pre(x)
if hasattr(self, "cond_layer"):
o = o + self.cond_layer(g)
for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o)
z_sum = None
for j in range(self.num_kernels):
if z_sum is None:
z_sum = self.resblocks[i * self.num_kernels + j](o)
else:
z_sum += self.resblocks[i * self.num_kernels + j](o)
o = z_sum / self.num_kernels
o = F.leaky_relu(o)
o = self.conv_post(o)
o = torch.tanh(o)
return o
@torch.no_grad()
def inference(self, c):
"""
Args:
x (Tensor): conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
remove_parametrizations(self.conv_pre, "weight")
remove_parametrizations(self.conv_post, "weight")
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()
@@ -0,0 +1,84 @@
import numpy as np
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
class MelganDiscriminator(nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=16,
max_channels=1024,
downsample_factors=(4, 4, 4, 4),
groups_denominator=4,
):
super().__init__()
self.layers = nn.ModuleList()
layer_kernel_size = np.prod(kernel_sizes)
layer_padding = (layer_kernel_size - 1) // 2
# initial layer
self.layers += [
nn.Sequential(
nn.ReflectionPad1d(layer_padding),
weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)),
nn.LeakyReLU(0.2, inplace=True),
)
]
# downsampling layers
layer_in_channels = base_channels
for downsample_factor in downsample_factors:
layer_out_channels = min(layer_in_channels * downsample_factor, max_channels)
layer_kernel_size = downsample_factor * 10 + 1
layer_padding = (layer_kernel_size - 1) // 2
layer_groups = layer_in_channels // groups_denominator
self.layers += [
nn.Sequential(
weight_norm(
nn.Conv1d(
layer_in_channels,
layer_out_channels,
kernel_size=layer_kernel_size,
stride=downsample_factor,
padding=layer_padding,
groups=layer_groups,
)
),
nn.LeakyReLU(0.2, inplace=True),
)
]
layer_in_channels = layer_out_channels
# last 2 layers
layer_padding1 = (kernel_sizes[0] - 1) // 2
layer_padding2 = (kernel_sizes[1] - 1) // 2
self.layers += [
nn.Sequential(
weight_norm(
nn.Conv1d(
layer_out_channels,
layer_out_channels,
kernel_size=kernel_sizes[0],
stride=1,
padding=layer_padding1,
)
),
nn.LeakyReLU(0.2, inplace=True),
),
weight_norm(
nn.Conv1d(
layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2
)
),
]
def forward(self, x):
feats = []
for layer in self.layers:
x = layer(x)
feats.append(x)
return x, feats
+95
View File
@@ -0,0 +1,95 @@
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
class MelganGenerator(nn.Module):
def __init__(
self,
in_channels=80,
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=(8, 8, 2, 2),
res_kernel=3,
num_res_blocks=3,
):
super().__init__()
# assert model parameters
assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."
# setup additional model parameters
base_padding = (proj_kernel - 1) // 2
act_slope = 0.2
self.inference_padding = 2
# initial layer
layers = []
layers += [
nn.ReflectionPad1d(base_padding),
weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
]
# upsampling layers and residual stacks
for idx, upsample_factor in enumerate(upsample_factors):
layer_in_channels = base_channels // (2**idx)
layer_out_channels = base_channels // (2 ** (idx + 1))
layer_filter_size = upsample_factor * 2
layer_stride = upsample_factor
layer_output_padding = upsample_factor % 2
layer_padding = upsample_factor // 2 + layer_output_padding
layers += [
nn.LeakyReLU(act_slope),
weight_norm(
nn.ConvTranspose1d(
layer_in_channels,
layer_out_channels,
layer_filter_size,
stride=layer_stride,
padding=layer_padding,
output_padding=layer_output_padding,
bias=True,
)
),
ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
]
layers += [nn.LeakyReLU(act_slope)]
# final layer
layers += [
nn.ReflectionPad1d(base_padding),
weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
nn.Tanh(),
]
self.layers = nn.Sequential(*layers)
def forward(self, c):
return self.layers(c)
def inference(self, c):
c = c.to(self.layers[1].weight.device)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.layers(c)
def remove_weight_norm(self):
for _, layer in enumerate(self.layers):
if len(layer.state_dict()) != 0:
try:
nn.utils.parametrize.remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()
@@ -0,0 +1,50 @@
from torch import nn
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
class MelganMultiscaleDiscriminator(nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
num_scales=3,
kernel_sizes=(5, 3),
base_channels=16,
max_channels=1024,
downsample_factors=(4, 4, 4),
pooling_kernel_size=4,
pooling_stride=2,
pooling_padding=2,
groups_denominator=4,
):
super().__init__()
self.discriminators = nn.ModuleList(
[
MelganDiscriminator(
in_channels=in_channels,
out_channels=out_channels,
kernel_sizes=kernel_sizes,
base_channels=base_channels,
max_channels=max_channels,
downsample_factors=downsample_factors,
groups_denominator=groups_denominator,
)
for _ in range(num_scales)
]
)
self.pooling = nn.AvgPool1d(
kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False
)
def forward(self, x):
scores = []
feats = []
for disc in self.discriminators:
score, feat = disc(x)
scores.append(score)
feats.append(feat)
x = self.pooling(x)
return scores, feats
@@ -0,0 +1,41 @@
import torch
from TTS.vocoder.layers.pqmf import PQMF
from TTS.vocoder.models.melgan_generator import MelganGenerator
class MultibandMelganGenerator(MelganGenerator):
def __init__(
self,
in_channels=80,
out_channels=4,
proj_kernel=7,
base_channels=384,
upsample_factors=(2, 8, 2, 2),
res_kernel=3,
num_res_blocks=3,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
proj_kernel=proj_kernel,
base_channels=base_channels,
upsample_factors=upsample_factors,
res_kernel=res_kernel,
num_res_blocks=num_res_blocks,
)
self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
def pqmf_analysis(self, x):
return self.pqmf_layer.analysis(x)
def pqmf_synthesis(self, x):
return self.pqmf_layer.synthesis(x)
@torch.no_grad()
def inference(self, cond_features):
cond_features = cond_features.to(self.layers[1].weight.device)
cond_features = torch.nn.functional.pad(
cond_features, (self.inference_padding, self.inference_padding), "replicate"
)
return self.pqmf_synthesis(self.layers(cond_features))
@@ -0,0 +1,187 @@
import math
import torch
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
class ParallelWaveganDiscriminator(nn.Module):
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
It classifies each audio window real/fake and returns a sequence
of predictions.
It is a stack of convolutional blocks with dilation.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=10,
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True,
):
super().__init__()
assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
assert dilation_factor > 0, " [!] dilation factor must be > 0."
self.conv_layers = nn.ModuleList()
conv_in_channels = in_channels
for i in range(num_layers - 1):
if i == 0:
dilation = 1
else:
dilation = i if dilation_factor == 1 else dilation_factor**i
conv_in_channels = conv_channels
padding = (kernel_size - 1) // 2 * dilation
conv_layer = [
nn.Conv1d(
conv_in_channels,
conv_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
bias=bias,
),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
]
self.conv_layers += conv_layer
padding = (kernel_size - 1) // 2
last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
self.conv_layers += [last_conv_layer]
self.apply_weight_norm()
def forward(self, x):
"""
x : (B, 1, T).
Returns:
Tensor: (B, 1, T)
"""
for f in self.conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
class ResidualParallelWaveganDiscriminator(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
dropout=0.0,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
):
super().__init__()
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
self.in_channels = in_channels
self.out_channels = out_channels
self.num_layers = num_layers
self.stacks = stacks
self.kernel_size = kernel_size
self.res_factor = math.sqrt(1.0 / num_layers)
# check the number of num_layers and stacks
assert num_layers % stacks == 0
layers_per_stack = num_layers // stacks
# define first convolution
self.first_conv = nn.Sequential(
nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
)
# define residual blocks
self.conv_layers = nn.ModuleList()
for layer in range(num_layers):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
res_channels=res_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=-1,
dilation=dilation,
dropout=dropout,
bias=bias,
use_causal_conv=False,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = nn.ModuleList(
[
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True),
]
)
# apply weight norm
self.apply_weight_norm()
def forward(self, x):
"""
x: (B, 1, T).
"""
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, None)
skips += h
skips *= self.res_factor
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
print(f"Weight norm is removed from {m}.")
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
@@ -0,0 +1,164 @@
import math
import numpy as np
import torch
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
It is similar to WaveNet with no causal convolution.
It is conditioned on an aux feature (spectrogram) to generate
an output waveform from an input noise.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
dropout=0.0,
bias=True,
use_weight_norm=True,
upsample_factors=[4, 4, 4, 4],
inference_padding=2,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.num_res_blocks = num_res_blocks
self.stacks = stacks
self.kernel_size = kernel_size
self.upsample_factors = upsample_factors
self.upsample_scale = np.prod(upsample_factors)
self.inference_padding = inference_padding
self.use_weight_norm = use_weight_norm
# check the number of layers and stacks
assert num_res_blocks % stacks == 0
layers_per_stack = num_res_blocks // stacks
# define first convolution
self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True)
# define conv + upsampling network
self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(num_res_blocks):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
res_channels=res_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = torch.nn.ModuleList(
[
torch.nn.ReLU(inplace=True),
torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True),
torch.nn.ReLU(inplace=True),
torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True),
]
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, c):
"""
c: (B, C ,T').
o: Output tensor (B, out_channels, T)
"""
# random noise
x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
x = x.to(self.first_conv.bias.device)
# perform upsampling
if c is not None and self.upsample_net is not None:
c = self.upsample_net(c)
assert (
c.shape[-1] == x.shape[-1]
), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
# encode to hidden representation
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, c)
skips += h
skips *= math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
@torch.no_grad()
def inference(self, c):
c = c.to(self.first_conv.weight.device)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
if self.use_weight_norm:
self.remove_weight_norm()
@@ -0,0 +1,203 @@
import numpy as np
from torch import nn
class GBlock(nn.Module):
def __init__(self, in_channels, cond_channels, downsample_factor):
super().__init__()
self.in_channels = in_channels
self.cond_channels = cond_channels
self.downsample_factor = downsample_factor
self.start = nn.Sequential(
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
nn.ReLU(),
nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1),
)
self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1)
self.end = nn.Sequential(
nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2)
)
self.residual = nn.Sequential(
nn.Conv1d(in_channels, in_channels * 2, kernel_size=1),
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
)
def forward(self, inputs, conditions):
outputs = self.start(inputs) + self.lc_conv1d(conditions)
outputs = self.end(outputs)
residual_outputs = self.residual(inputs)
outputs = outputs + residual_outputs
return outputs
class DBlock(nn.Module):
def __init__(self, in_channels, out_channels, downsample_factor):
super().__init__()
self.in_channels = in_channels
self.downsample_factor = downsample_factor
self.out_channels = out_channels
self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor)
self.layers = nn.Sequential(
nn.ReLU(),
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
)
self.residual = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=1),
)
def forward(self, inputs):
if self.downsample_factor > 1:
outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs))
else:
outputs = self.layers(inputs) + self.residual(inputs)
return outputs
class ConditionalDiscriminator(nn.Module):
def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)):
super().__init__()
assert len(downsample_factors) == len(out_channels) + 1
self.in_channels = in_channels
self.cond_channels = cond_channels
self.downsample_factors = downsample_factors
self.out_channels = out_channels
self.pre_cond_layers = nn.ModuleList()
self.post_cond_layers = nn.ModuleList()
# layers before condition features
self.pre_cond_layers += [DBlock(in_channels, 64, 1)]
in_channels = 64
for i, channel in enumerate(out_channels):
self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i]))
in_channels = channel
# condition block
self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1])
# layers after condition block
self.post_cond_layers += [
DBlock(in_channels * 2, in_channels * 2, 1),
DBlock(in_channels * 2, in_channels * 2, 1),
nn.AdaptiveAvgPool1d(1),
nn.Conv1d(in_channels * 2, 1, kernel_size=1),
]
def forward(self, inputs, conditions):
batch_size = inputs.size()[0]
outputs = inputs.view(batch_size, self.in_channels, -1)
for layer in self.pre_cond_layers:
outputs = layer(outputs)
outputs = self.cond_block(outputs, conditions)
for layer in self.post_cond_layers:
outputs = layer(outputs)
return outputs
class UnconditionalDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)):
super().__init__()
self.downsample_factors = downsample_factors
self.in_channels = in_channels
self.downsample_factors = downsample_factors
self.out_channels = out_channels
self.layers = nn.ModuleList()
self.layers += [DBlock(self.in_channels, base_channels, 1)]
in_channels = base_channels
for i, factor in enumerate(downsample_factors):
self.layers.append(DBlock(in_channels, out_channels[i], factor))
in_channels *= 2
self.layers += [
DBlock(in_channels, in_channels, 1),
DBlock(in_channels, in_channels, 1),
nn.AdaptiveAvgPool1d(1),
nn.Conv1d(in_channels, 1, kernel_size=1),
]
def forward(self, inputs):
batch_size = inputs.size()[0]
outputs = inputs.view(batch_size, self.in_channels, -1)
for layer in self.layers:
outputs = layer(outputs)
return outputs
class RandomWindowDiscriminator(nn.Module):
"""Random Window Discriminator as described in
http://arxiv.org/abs/1909.11646"""
def __init__(
self,
cond_channels,
hop_length,
uncond_disc_donwsample_factors=(8, 4),
cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)),
cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)),
window_sizes=(512, 1024, 2048, 4096, 8192),
):
super().__init__()
self.cond_channels = cond_channels
self.window_sizes = window_sizes
self.hop_length = hop_length
self.base_window_size = self.hop_length * 2
self.ks = [ws // self.base_window_size for ws in window_sizes]
# check arguments
assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes)
for ws in window_sizes:
assert ws % hop_length == 0
for idx, cf in enumerate(cond_disc_downsample_factors):
assert np.prod(cf) == hop_length // self.ks[idx]
# define layers
self.unconditional_discriminators = nn.ModuleList([])
for k in self.ks:
layer = UnconditionalDiscriminator(
in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors
)
self.unconditional_discriminators.append(layer)
self.conditional_discriminators = nn.ModuleList([])
for idx, k in enumerate(self.ks):
layer = ConditionalDiscriminator(
in_channels=k,
cond_channels=cond_channels,
downsample_factors=cond_disc_downsample_factors[idx],
out_channels=cond_disc_out_channels[idx],
)
self.conditional_discriminators.append(layer)
def forward(self, x, c):
scores = []
feats = []
# unconditional pass
for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators):
index = np.random.randint(x.shape[-1] - window_size)
score = layer(x[:, :, index : index + window_size])
scores.append(score)
# conditional pass
for window_size, layer in zip(self.window_sizes, self.conditional_discriminators):
frame_size = window_size // self.hop_length
lc_index = np.random.randint(c.shape[-1] - frame_size)
sample_index = lc_index * self.hop_length
x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length]
c_sub = c[:, :, lc_index : lc_index + frame_size]
score = layer(x_sub, c_sub)
scores.append(score)
return scores, feats
@@ -0,0 +1,95 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
LRELU_SLOPE = 0.1
class SpecDiscriminator(nn.Module):
"""docstring for Discriminator."""
def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False):
super().__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.fft_size = fft_size
self.hop_length = hop_length
self.win_length = win_length
self.stft = TorchSTFT(fft_size, hop_length, win_length)
self.discriminators = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
]
)
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
def forward(self, y):
fmap = []
with torch.no_grad():
y = y.squeeze(1)
y = self.stft(y)
y = y.unsqueeze(1)
for _, d in enumerate(self.discriminators):
y = d(y)
y = F.leaky_relu(y, LRELU_SLOPE)
fmap.append(y)
y = self.out(y)
fmap.append(y)
return torch.flatten(y, 1, -1), fmap
class MultiResSpecDiscriminator(torch.nn.Module):
def __init__( # pylint: disable=dangerous-default-value
self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window"
):
super().__init__()
self.discriminators = nn.ModuleList(
[
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window),
]
)
def forward(self, x):
scores = []
feats = []
for d in self.discriminators:
score, feat = d(x)
scores.append(score)
feats.append(feat)
return scores, feats
class UnivnetDiscriminator(nn.Module):
"""Univnet discriminator wrapping MPD and MSD."""
def __init__(self):
super().__init__()
self.mpd = MultiPeriodDiscriminator()
self.msd = MultiResSpecDiscriminator()
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
scores_, feats_ = self.msd(x)
return scores + scores_, feats + feats_
+157
View File
@@ -0,0 +1,157 @@
from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock
LRELU_SLOPE = 0.1
class UnivnetGenerator(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
cond_channels: int,
upsample_factors: List[int],
lvc_layers_each_block: int,
lvc_kernel_size: int,
kpnet_hidden_channels: int,
kpnet_conv_size: int,
dropout: float,
use_weight_norm=True,
):
"""Univnet Generator network.
Paper: https://arxiv.org/pdf/2106.07889.pdf
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of channels of the output tensor.
hidden_channels (int): Number of hidden network channels.
cond_channels (int): Number of channels of the conditioning tensors.
upsample_factors (List[int]): List of uplsample factors for the upsampling layers.
lvc_layers_each_block (int): Number of LVC layers in each block.
lvc_kernel_size (int): Kernel size of the LVC layers.
kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
kpnet_conv_size (int): Number of convolution channels in the key-point network.
dropout (float): Dropout rate.
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.cond_channels = cond_channels
self.upsample_scale = np.prod(upsample_factors)
self.lvc_block_nums = len(upsample_factors)
# define first convolution
self.first_conv = torch.nn.Conv1d(
in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
)
# define residual blocks
self.lvc_blocks = torch.nn.ModuleList()
cond_hop_length = 1
for n in range(self.lvc_block_nums):
cond_hop_length = cond_hop_length * upsample_factors[n]
lvcb = LVCBlock(
in_channels=hidden_channels,
cond_channels=cond_channels,
upsample_ratio=upsample_factors[n],
conv_layers=lvc_layers_each_block,
conv_kernel_size=lvc_kernel_size,
cond_hop_length=cond_hop_length,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=dropout,
)
self.lvc_blocks += [lvcb]
# define output layers
self.last_conv_layers = torch.nn.ModuleList(
[
torch.nn.Conv1d(
hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
),
]
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, c):
"""Calculate forward propagation.
Args:
c (Tensor): Local conditioning auxiliary features (B, C ,T').
Returns:
Tensor: Output tensor (B, out_channels, T)
"""
# random noise
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
x = x.to(self.first_conv.bias.device)
x = self.first_conv(x)
for n in range(self.lvc_block_nums):
x = self.lvc_blocks[n](x, c)
# apply final layers
for f in self.last_conv_layers:
x = F.leaky_relu(x, LRELU_SLOPE)
x = f(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
parametrize.remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
"""Return receptive field size."""
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
@torch.no_grad()
def inference(self, c):
"""Perform inference.
Args:
c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
Returns:
Tensor: Output tensor (T, out_channels)
"""
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
x = x.to(self.first_conv.bias.device)
c = c.to(next(self.parameters()))
return self.forward(c)
+345
View File
@@ -0,0 +1,345 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results
@dataclass
class WavegradArgs(Coqpit):
in_channels: int = 80
out_channels: int = 1
use_weight_norm: bool = False
y_conv_channels: int = 32
x_conv_channels: int = 768
dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
upsample_dilations: List[List[int]] = field(
default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
)
class Wavegrad(BaseVocoder):
"""🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713
Examples:
Initializing the model.
>>> from TTS.vocoder.configs import WavegradConfig
>>> config = WavegradConfig()
>>> model = Wavegrad(config)
Paper Abstract:
This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
Audio samples are available at this https URL.
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__(config)
self.config = config
self.use_weight_norm = config.model_params.use_weight_norm
self.hop_len = np.prod(config.model_params.upsample_factors)
self.noise_level = None
self.num_steps = None
self.beta = None
self.alpha = None
self.alpha_hat = None
self.c1 = None
self.c2 = None
self.sigma = None
# dblocks
self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
self.dblocks = nn.ModuleList([])
ic = config.model_params.y_conv_channels
for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
self.dblocks.append(DBlock(ic, oc, df))
ic = oc
# film
self.film = nn.ModuleList([])
ic = config.model_params.y_conv_channels
for oc in reversed(config.model_params.ublock_out_channels):
self.film.append(FiLM(ic, oc))
ic = oc
# ublocksn
self.ublocks = nn.ModuleList([])
ic = config.model_params.x_conv_channels
for oc, uf, ud in zip(
config.model_params.ublock_out_channels,
config.model_params.upsample_factors,
config.model_params.upsample_dilations,
):
self.ublocks.append(UBlock(ic, oc, uf, ud))
ic = oc
self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)
if config.model_params.use_weight_norm:
self.apply_weight_norm()
def forward(self, x, spectrogram, noise_scale):
shift_and_scale = []
x = self.y_conv(x)
shift_and_scale.append(self.film[0](x, noise_scale))
for film, layer in zip(self.film[1:], self.dblocks):
x = layer(x)
shift_and_scale.append(film(x, noise_scale))
x = self.x_conv(spectrogram)
for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)):
x = layer(x, film_shift, film_scale)
x = self.out_conv(x)
return x
def load_noise_schedule(self, path):
beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg
self.compute_noise_level(beta)
@torch.no_grad()
def inference(self, x, y_n=None):
"""
Shapes:
x: :math:`[B, C , T]`
y_n: :math:`[B, 1, T]`
"""
if y_n is None:
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
else:
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
y_n = y_n.type_as(x)
sqrt_alpha_hat = self.noise_level.to(x)
for n in range(len(self.alpha) - 1, -1, -1):
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
if n > 0:
z = torch.randn_like(y_n)
y_n += self.sigma[n - 1] * z
y_n.clamp_(-1.0, 1.0)
return y_n
def compute_y_n(self, y_0):
"""Compute noisy audio based on noise schedule"""
self.noise_level = self.noise_level.to(y_0)
if len(y_0.shape) == 3:
y_0 = y_0.squeeze(1)
s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
l_a, l_b = self.noise_level[s], self.noise_level[s + 1]
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
noise_scale = noise_scale.unsqueeze(1)
noise = torch.randn_like(y_0)
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
def compute_noise_level(self, beta):
"""Compute noise schedule parameters"""
self.num_steps = len(beta)
alpha = 1 - beta
alpha_hat = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0)
noise_level = alpha_hat**0.5
# pylint: disable=not-callable
self.beta = torch.tensor(beta.astype(np.float32))
self.alpha = torch.tensor(alpha.astype(np.float32))
self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
self.noise_level = torch.tensor(noise_level.astype(np.float32))
self.c1 = 1 / self.alpha**0.5
self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5
def remove_weight_norm(self):
for _, layer in enumerate(self.dblocks):
if len(layer.state_dict()) != 0:
try:
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
for _, layer in enumerate(self.film):
if len(layer.state_dict()) != 0:
try:
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
for _, layer in enumerate(self.ublocks):
if len(layer.state_dict()) != 0:
try:
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
remove_parametrizations(self.x_conv, "weight")
remove_parametrizations(self.out_conv, "weight")
remove_parametrizations(self.y_conv, "weight")
def apply_weight_norm(self):
for _, layer in enumerate(self.dblocks):
if len(layer.state_dict()) != 0:
layer.apply_weight_norm()
for _, layer in enumerate(self.film):
if len(layer.state_dict()) != 0:
layer.apply_weight_norm()
for _, layer in enumerate(self.ublocks):
if len(layer.state_dict()) != 0:
layer.apply_weight_norm()
self.x_conv = weight_norm(self.x_conv)
self.out_conv = weight_norm(self.out_conv)
self.y_conv = weight_norm(self.y_conv)
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
if self.config.model_params.use_weight_norm:
self.remove_weight_norm()
betas = np.linspace(
config["test_noise_schedule"]["min_val"],
config["test_noise_schedule"]["max_val"],
config["test_noise_schedule"]["num_steps"],
)
self.compute_noise_level(betas)
else:
betas = np.linspace(
config["train_noise_schedule"]["min_val"],
config["train_noise_schedule"]["max_val"],
config["train_noise_schedule"]["num_steps"],
)
self.compute_noise_level(betas)
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
# format data
x = batch["input"]
y = batch["waveform"]
# set noise scale
noise, x_noisy, noise_scale = self.compute_y_n(y)
# forward pass
noise_hat = self.forward(x_noisy, x, noise_scale)
# compute losses
loss = criterion(noise, noise_hat)
return {"model_output": noise_hat}, {"loss": loss}
def train_log( # pylint: disable=no-self-use
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
pass
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
def eval_log( # pylint: disable=no-self-use
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> None:
pass
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
# setup noise schedule and inference
ap = assets["audio_processor"]
noise_schedule = self.config["test_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
samples = test_loader.dataset.load_test_samples(1)
for sample in samples:
x = sample[0]
x = x[None, :, :].to(next(self.parameters()).device)
y = sample[1]
y = y[None, :]
# compute voice
y_pred = self.inference(x)
# compute spectrograms
figures = plot_results(y_pred, y, ap, "test")
# Sample audio
sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
return figures, {"test/audio": sample_voice}
def get_optimizer(self):
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
def get_scheduler(self, optimizer):
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
@staticmethod
def get_criterion():
return torch.nn.L1Loss()
@staticmethod
def format_batch(batch: Dict) -> Dict:
# return a whole audio segment
m, y = batch[0], batch[1]
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
ap = assets["audio_processor"]
dataset = WaveGradDataset(
ap=ap,
items=samples,
seq_len=self.config.seq_len,
hop_len=ap.hop_length,
pad_short=self.config.pad_short,
conv_pad=self.config.conv_pad,
is_training=not is_eval,
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=num_gpus <= 1,
drop_last=False,
sampler=sampler,
num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
pin_memory=False,
)
return loader
def on_epoch_start(self, trainer): # pylint: disable=unused-argument
noise_schedule = self.config["train_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
@staticmethod
def init_from_config(config: "WavegradConfig"):
return Wavegrad(config)
+646
View File
@@ -0,0 +1,646 @@
import sys
import time
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
def stream(string, variables):
sys.stdout.write(f"\r{string}" % variables)
# pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305
class ResBlock(nn.Module):
def __init__(self, dims):
super().__init__()
self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
self.batch_norm1 = nn.BatchNorm1d(dims)
self.batch_norm2 = nn.BatchNorm1d(dims)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.batch_norm1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.batch_norm2(x)
return x + residual
class MelResNet(nn.Module):
def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
super().__init__()
k_size = pad * 2 + 1
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
self.batch_norm = nn.BatchNorm1d(compute_dims)
self.layers = nn.ModuleList()
for _ in range(num_res_blocks):
self.layers.append(ResBlock(compute_dims))
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
def forward(self, x):
x = self.conv_in(x)
x = self.batch_norm(x)
x = F.relu(x)
for f in self.layers:
x = f(x)
x = self.conv_out(x)
return x
class Stretch2d(nn.Module):
def __init__(self, x_scale, y_scale):
super().__init__()
self.x_scale = x_scale
self.y_scale = y_scale
def forward(self, x):
b, c, h, w = x.size()
x = x.unsqueeze(-1).unsqueeze(3)
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
return x.view(b, c, h * self.y_scale, w * self.x_scale)
class UpsampleNetwork(nn.Module):
def __init__(
self,
feat_dims,
upsample_scales,
compute_dims,
num_res_blocks,
res_out_dims,
pad,
use_aux_net,
):
super().__init__()
self.total_scale = np.cumproduct(upsample_scales)[-1]
self.indent = pad * self.total_scale
self.use_aux_net = use_aux_net
if use_aux_net:
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
self.resnet_stretch = Stretch2d(self.total_scale, 1)
self.up_layers = nn.ModuleList()
for scale in upsample_scales:
k_size = (1, scale * 2 + 1)
padding = (0, scale)
stretch = Stretch2d(scale, 1)
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
conv.weight.data.fill_(1.0 / k_size[1])
self.up_layers.append(stretch)
self.up_layers.append(conv)
def forward(self, m):
if self.use_aux_net:
aux = self.resnet(m).unsqueeze(1)
aux = self.resnet_stretch(aux)
aux = aux.squeeze(1)
aux = aux.transpose(1, 2)
else:
aux = None
m = m.unsqueeze(1)
for f in self.up_layers:
m = f(m)
m = m.squeeze(1)[:, :, self.indent : -self.indent]
return m.transpose(1, 2), aux
class Upsample(nn.Module):
def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net):
super().__init__()
self.scale = scale
self.pad = pad
self.indent = pad * scale
self.use_aux_net = use_aux_net
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
def forward(self, m):
if self.use_aux_net:
aux = self.resnet(m)
aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True)
aux = aux.transpose(1, 2)
else:
aux = None
m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True)
m = m[:, :, self.indent : -self.indent]
m = m * 0.045 # empirically found
return m.transpose(1, 2), aux
@dataclass
class WavernnArgs(Coqpit):
"""🐸 WaveRNN model arguments.
rnn_dims (int):
Number of hidden channels in RNN layers. Defaults to 512.
fc_dims (int):
Number of hidden channels in fully-conntected layers. Defaults to 512.
compute_dims (int):
Number of hidden channels in the feature ResNet. Defaults to 128.
res_out_dim (int):
Number of hidden channels in the feature ResNet output. Defaults to 128.
num_res_blocks (int):
Number of residual blocks in the ResNet. Defaults to 10.
use_aux_net (bool):
enable/disable the feature ResNet. Defaults to True.
use_upsample_net (bool):
enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True.
upsample_factors (list):
Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```.
mode (str):
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
Gaussian Distribution and `bits` for quantized bits as the model's output.
mulaw (bool):
enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
to `True`.
pad (int):
Padding applied to the input feature frames against the convolution layers of the feature network.
Defaults to 2.
"""
rnn_dims: int = 512
fc_dims: int = 512
compute_dims: int = 128
res_out_dims: int = 128
num_res_blocks: int = 10
use_aux_net: bool = True
use_upsample_net: bool = True
upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
mode: str = "mold" # mold [string], gauss [string], bits [int]
mulaw: bool = True # apply mulaw if mode is bits
pad: int = 2
feat_dims: int = 80
class Wavernn(BaseVocoder):
def __init__(self, config: Coqpit):
"""🐸 WaveRNN model.
Original paper - https://arxiv.org/abs/1802.08435
Official implementation - https://github.com/fatchord/WaveRNN
Args:
config (Coqpit): [description]
Raises:
RuntimeError: [description]
Examples:
>>> from TTS.vocoder.configs import WavernnConfig
>>> config = WavernnConfig()
>>> model = Wavernn(config)
Paper Abstract:
Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
both estimating the data distribution and generating high-quality samples. Efficient sampling for this
class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
describe a set of general techniques for reducing sampling time while maintaining high output quality.
We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
parameters, large sparse networks perform better than small dense networks and this relationship holds for
sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
orthogonal method for increasing sampling efficiency.
"""
super().__init__(config)
if isinstance(self.args.mode, int):
self.n_classes = 2**self.args.mode
elif self.args.mode == "mold":
self.n_classes = 3 * 10
elif self.args.mode == "gauss":
self.n_classes = 2
else:
raise RuntimeError("Unknown model mode value - ", self.args.mode)
self.ap = AudioProcessor(**config.audio.to_dict())
self.aux_dims = self.args.res_out_dims // 4
if self.args.use_upsample_net:
assert (
np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length
), " [!] upsample scales needs to be equal to hop_length"
self.upsample = UpsampleNetwork(
self.args.feat_dims,
self.args.upsample_factors,
self.args.compute_dims,
self.args.num_res_blocks,
self.args.res_out_dims,
self.args.pad,
self.args.use_aux_net,
)
else:
self.upsample = Upsample(
config.audio.hop_length,
self.args.pad,
self.args.num_res_blocks,
self.args.feat_dims,
self.args.compute_dims,
self.args.res_out_dims,
self.args.use_aux_net,
)
if self.args.use_aux_net:
self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
else:
self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
def forward(self, x, mels):
bsize = x.size(0)
h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
mels, aux = self.upsample(mels)
if self.args.use_aux_net:
aux_idx = [self.aux_dims * i for i in range(5)]
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
x = (
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
if self.args.use_aux_net
else torch.cat([x.unsqueeze(-1), mels], dim=2)
)
x = self.I(x)
res = x
self.rnn1.flatten_parameters()
x, _ = self.rnn1(x, h1)
x = x + res
res = x
x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
self.rnn2.flatten_parameters()
x, _ = self.rnn2(x, h2)
x = x + res
x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
x = F.relu(self.fc1(x))
x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
x = F.relu(self.fc2(x))
return self.fc3(x)
def inference(self, mels, batched=None, target=None, overlap=None):
self.eval()
output = []
start = time.time()
rnn1 = self.get_gru_cell(self.rnn1)
rnn2 = self.get_gru_cell(self.rnn2)
with torch.no_grad():
if isinstance(mels, np.ndarray):
mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
if mels.ndim == 2:
mels = mels.unsqueeze(0)
wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
mels, aux = self.upsample(mels.transpose(1, 2))
if batched:
mels = self.fold_with_overlap(mels, target, overlap)
if aux is not None:
aux = self.fold_with_overlap(aux, target, overlap)
b_size, seq_len, _ = mels.size()
h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
x = torch.zeros(b_size, 1).type_as(mels)
if self.args.use_aux_net:
d = self.aux_dims
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
for i in range(seq_len):
m_t = mels[:, i, :]
if self.args.use_aux_net:
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
x = self.I(x)
h1 = rnn1(x, h1)
x = x + h1
inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
h2 = rnn2(inp, h2)
x = x + h2
x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
x = F.relu(self.fc1(x))
x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
x = F.relu(self.fc2(x))
logits = self.fc3(x)
if self.args.mode == "mold":
sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
x = sample.transpose(0, 1).type_as(mels)
elif self.args.mode == "gauss":
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
x = sample.transpose(0, 1).type_as(mels)
elif isinstance(self.args.mode, int):
posterior = F.softmax(logits, dim=1)
distrib = torch.distributions.Categorical(posterior)
sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
output.append(sample)
x = sample.unsqueeze(-1)
else:
raise RuntimeError("Unknown model mode value - ", self.args.mode)
if i % 100 == 0:
self.gen_display(i, seq_len, b_size, start)
output = torch.stack(output).transpose(0, 1)
output = output.cpu()
if batched:
output = output.numpy()
output = output.astype(np.float64)
output = self.xfade_and_unfold(output, target, overlap)
else:
output = output[0]
if self.args.mulaw and isinstance(self.args.mode, int):
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
output = output[:wave_len]
if wave_len > len(fade_out):
output[-20 * self.config.audio.hop_length :] *= fade_out
self.train()
return output
def gen_display(self, i, seq_len, b_size, start):
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
stream(
"%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
(i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
)
def fold_with_overlap(self, x, target, overlap):
"""Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold()
Args:
x (tensor) : Upsampled conditioning features.
shape=(1, timesteps, features)
target (int) : Target timesteps for each index of batch
overlap (int) : Timesteps for both xfade and rnn warmup
Return:
(tensor) : shape=(num_folds, target + 2 * overlap, features)
Details:
x = [[h1, h2, ... hn]]
Where each h is a vector of conditioning features
Eg: target=2, overlap=1 with x.size(1)=10
folded = [[h1, h2, h3, h4],
[h4, h5, h6, h7],
[h7, h8, h9, h10]]
"""
_, total_len, features = x.size()
# Calculate variables needed
num_folds = (total_len - overlap) // (target + overlap)
extended_len = num_folds * (overlap + target) + overlap
remaining = total_len - extended_len
# Pad if some time steps poking out
if remaining != 0:
num_folds += 1
padding = target + 2 * overlap - remaining
x = self.pad_tensor(x, padding, side="after")
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
# Get the values for the folded tensor
for i in range(num_folds):
start = i * (target + overlap)
end = start + target + 2 * overlap
folded[i] = x[:, start:end, :]
return folded
@staticmethod
def get_gru_cell(gru):
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
gru_cell.weight_hh.data = gru.weight_hh_l0.data
gru_cell.weight_ih.data = gru.weight_ih_l0.data
gru_cell.bias_hh.data = gru.bias_hh_l0.data
gru_cell.bias_ih.data = gru.bias_ih_l0.data
return gru_cell
@staticmethod
def pad_tensor(x, pad, side="both"):
# NB - this is just a quick method i need right now
# i.e., it won't generalise to other shapes/dims
b, t, c = x.size()
total = t + 2 * pad if side == "both" else t + pad
padded = torch.zeros(b, total, c).to(x.device)
if side in ("before", "both"):
padded[:, pad : pad + t, :] = x
elif side == "after":
padded[:, :t, :] = x
return padded
@staticmethod
def xfade_and_unfold(y, target, overlap):
"""Applies a crossfade and unfolds into a 1d array.
Args:
y (ndarry) : Batched sequences of audio samples
shape=(num_folds, target + 2 * overlap)
dtype=np.float64
overlap (int) : Timesteps for both xfade and rnn warmup
Return:
(ndarry) : audio samples in a 1d array
shape=(total_len)
dtype=np.float64
Details:
y = [[seq1],
[seq2],
[seq3]]
Apply a gain envelope at both ends of the sequences
y = [[seq1_in, seq1_target, seq1_out],
[seq2_in, seq2_target, seq2_out],
[seq3_in, seq3_target, seq3_out]]
Stagger and add up the groups of samples:
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
"""
num_folds, length = y.shape
target = length - 2 * overlap
total_len = num_folds * (target + overlap) + overlap
# Need some silence for the rnn warmup
silence_len = overlap // 2
fade_len = overlap - silence_len
silence = np.zeros((silence_len), dtype=np.float64)
# Equal power crossfade
t = np.linspace(-1, 1, fade_len, dtype=np.float64)
fade_in = np.sqrt(0.5 * (1 + t))
fade_out = np.sqrt(0.5 * (1 - t))
# Concat the silence to the fades
fade_in = np.concatenate([silence, fade_in])
fade_out = np.concatenate([fade_out, silence])
# Apply the gain to the overlap samples
y[:, :overlap] *= fade_in
y[:, -overlap:] *= fade_out
unfolded = np.zeros((total_len), dtype=np.float64)
# Loop to add up all the samples
for i in range(num_folds):
start = i * (target + overlap)
end = start + target + 2 * overlap
unfolded[start:end] += y[i]
return unfolded
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
mels = batch["input"]
waveform = batch["waveform"]
waveform_coarse = batch["waveform_coarse"]
y_hat = self.forward(waveform, mels)
if isinstance(self.args.mode, int):
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
else:
waveform_coarse = waveform_coarse.float()
waveform_coarse = waveform_coarse.unsqueeze(-1)
# compute losses
loss_dict = criterion(y_hat, waveform_coarse)
return {"model_output": y_hat}, loss_dict
def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
@torch.no_grad()
def test(
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, Dict]:
ap = self.ap
figures = {}
audios = {}
samples = test_loader.dataset.load_test_samples(1)
for idx, sample in enumerate(samples):
x = torch.FloatTensor(sample[0])
x = x.to(next(self.parameters()).device)
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
x_hat = ap.melspectrogram(y_hat)
figures.update(
{
f"test_{idx}/ground_truth": plot_spectrogram(x.T),
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
}
)
audios.update({f"test_{idx}/audio": y_hat})
# audios.update({f"real_{idx}/audio": y_hat})
return figures, audios
def test_log(
self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
figures, audios = outputs
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
def format_batch(batch: Dict) -> Dict:
waveform = batch[0]
mels = batch[1]
waveform_coarse = batch[2]
return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}
def get_data_loader( # pylint: disable=no-self-use
self,
config: Coqpit,
assets: Dict,
is_eval: True,
samples: List,
verbose: bool,
num_gpus: int,
):
ap = self.ap
dataset = WaveRNNDataset(
ap=ap,
items=samples,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad=config.model_args.pad,
mode=config.model_args.mode,
mulaw=config.model_args.mulaw,
is_training=not is_eval,
verbose=verbose,
)
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=1 if is_eval else config.batch_size,
shuffle=num_gpus == 0,
collate_fn=dataset.collate,
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=True,
)
return loader
def get_criterion(self):
# define train functions
return WaveRNNLoss(self.args.mode)
@staticmethod
def init_from_config(config: "WavernnConfig"):
return Wavernn(config)
Binary file not shown.
View File
+154
View File
@@ -0,0 +1,154 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions.normal import Normal
def gaussian_loss(y_hat, y, log_std_min=-7.0):
assert y_hat.dim() == 3
assert y_hat.size(2) == 2
mean = y_hat[:, :, :1]
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
# TODO: replace with pytorch dist
log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
return log_probs.squeeze().mean()
def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
assert y_hat.size(2) == 2
mean = y_hat[:, :, :1]
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
dist = Normal(
mean,
torch.exp(log_std),
)
sample = dist.sample()
sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
del dist
return sample
def log_sum_exp(x):
"""numerically stable log_sum_exp implementation that prevents overflow"""
# TF ordering
axis = len(x.size()) - 1
m, _ = torch.max(x, dim=axis)
m2, _ = torch.max(x, dim=axis, keepdim=True)
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True):
if log_scale_min is None:
log_scale_min = float(np.log(1e-14))
y_hat = y_hat.permute(0, 2, 1)
assert y_hat.dim() == 3
assert y_hat.size(1) % 3 == 0
nr_mix = y_hat.size(1) // 3
# (B x T x C)
y_hat = y_hat.transpose(1, 2)
# unpack parameters. (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix : 2 * nr_mix]
log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)
# B x T x 1 -> B x T x num_mixtures
y = y.expand_as(means)
centered_y = y - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
cdf_plus = torch.sigmoid(plus_in)
min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
cdf_min = torch.sigmoid(min_in)
# log probability for edge case of 0 (before scaling)
# equivalent: torch.log(F.sigmoid(plus_in))
log_cdf_plus = plus_in - F.softplus(plus_in)
# log probability for edge case of 255 (before scaling)
# equivalent: (1 - F.sigmoid(min_in)).log()
log_one_minus_cdf_min = -F.softplus(min_in)
# probability for all other cases
cdf_delta = cdf_plus - cdf_min
mid_in = inv_stdv * centered_y
# log probability in the center of the bin, to be used in extreme cases
# (not actually used in our code)
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
# tf equivalent
# log_probs = tf.where(x < -0.999, log_cdf_plus,
# tf.where(x > 0.999, log_one_minus_cdf_min,
# tf.where(cdf_delta > 1e-5,
# tf.log(tf.maximum(cdf_delta, 1e-12)),
# log_pdf_mid - np.log(127.5))))
# TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
# for num_classes=65536 case? 1e-7? not sure..
inner_inner_cond = (cdf_delta > 1e-5).float()
inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
log_pdf_mid - np.log((num_classes - 1) / 2)
)
inner_cond = (y > 0.999).float()
inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
log_probs = log_probs + F.log_softmax(logit_probs, -1)
if reduce:
return -torch.mean(log_sum_exp(log_probs))
return -log_sum_exp(log_probs).unsqueeze(-1)
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
"""
Sample from discretized mixture of logistic distributions
Args:
y (Tensor): :math:`[B, C, T]`
log_scale_min (float): Log scale minimum value
Returns:
Tensor: sample in range of [-1, 1].
"""
if log_scale_min is None:
log_scale_min = float(np.log(1e-14))
assert y.size(1) % 3 == 0
nr_mix = y.size(1) // 3
# B x T x C
y = y.transpose(1, 2)
logit_probs = y[:, :, :nr_mix]
# sample mixture indicator from softmax
temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
temp = logit_probs.data - torch.log(-torch.log(temp))
_, argmax = temp.max(dim=-1)
# (B, T) -> (B, T, nr_mix)
one_hot = to_one_hot(argmax, nr_mix)
# select logistic parameters
means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
# sample from logistic & clip to interval
# we don't actually round to the nearest 8bit value when sampling
u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))
x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)
return x
def to_one_hot(tensor, n, fill_with=1.0):
# we perform one hot encore with respect to the last axis
one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor)
one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
return one_hot
+72
View File
@@ -0,0 +1,72 @@
from typing import Dict
import numpy as np
import torch
from matplotlib import pyplot as plt
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
def interpolate_vocoder_input(scale_factor, spec):
"""Interpolate spectrogram by the scale factor.
It is mainly used to match the sampling rates of
the tts and vocoder models.
Args:
scale_factor (float): scale factor to interpolate the spectrogram
spec (np.array): spectrogram to be interpolated
Returns:
torch.tensor: interpolated spectrogram.
"""
print(" > before interpolation :", spec.shape)
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
spec = torch.nn.functional.interpolate(
spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
).squeeze(0)
print(" > after interpolation :", spec.shape)
return spec
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
"""Plot the predicted and the real waveform and their spectrograms.
Args:
y_hat (torch.tensor): Predicted waveform.
y (torch.tensor): Real waveform.
ap (AudioProcessor): Audio processor used to process the waveform.
name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
Returns:
Dict: output figures keyed by the name of the figures.
""" """Plot vocoder model results"""
if name_prefix is None:
name_prefix = ""
# select an instance from batch
y_hat = y_hat[0].squeeze().detach().cpu().numpy()
y = y[0].squeeze().detach().cpu().numpy()
spec_fake = ap.melspectrogram(y_hat).T
spec_real = ap.melspectrogram(y).T
spec_diff = np.abs(spec_fake - spec_real)
# plot figure and save it
fig_wave = plt.figure()
plt.subplot(2, 1, 1)
plt.plot(y)
plt.title("groundtruth speech")
plt.subplot(2, 1, 2)
plt.plot(y_hat)
plt.title("generated speech")
plt.tight_layout()
plt.close()
figures = {
name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
name_prefix + "speech_comparison": fig_wave,
}
return figures