Add files via upload
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
# Mozilla TTS Vocoders (Experimental)
|
||||
|
||||
Here there are vocoder model implementations which can be combined with the other TTS models.
|
||||
|
||||
Currently, following models are implemented:
|
||||
|
||||
- Melgan
|
||||
- MultiBand-Melgan
|
||||
- ParallelWaveGAN
|
||||
- GAN-TTS (Discriminator Only)
|
||||
|
||||
It is also very easy to adapt different vocoder models as we provide a flexible and modular (but not too modular) framework.
|
||||
|
||||
## Training a model
|
||||
|
||||
You can see here an example (Soon)[Colab Notebook]() training MelGAN with LJSpeech dataset.
|
||||
|
||||
In order to train a new model, you need to gather all wav files into a folder and give this folder to `data_path` in '''config.json'''
|
||||
|
||||
You need to define other relevant parameters in your ```config.json``` and then start traning with the following command.
|
||||
|
||||
```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --config_path path/to/config.json```
|
||||
|
||||
Example config files can be found under `tts/vocoder/configs/` folder.
|
||||
|
||||
You can continue a previous training run by the following command.
|
||||
|
||||
```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --continue_path path/to/your/model/folder```
|
||||
|
||||
You can fine-tune a pre-trained model by the following command.
|
||||
|
||||
```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --restore_path path/to/your/model.pth```
|
||||
|
||||
Restoring a model starts a new training in a different folder. It only restores model weights with the given checkpoint file. However, continuing a training starts from the same directory where the previous training run left off.
|
||||
|
||||
You can also follow your training runs on Tensorboard as you do with our TTS models.
|
||||
|
||||
## Acknowledgement
|
||||
Thanks to @kan-bayashi for his [repository](https://github.com/kan-bayashi/ParallelWaveGAN) being the start point of our work.
|
||||
Binary file not shown.
@@ -0,0 +1,17 @@
|
||||
import importlib
|
||||
import os
|
||||
from inspect import isclass
|
||||
|
||||
# import all files under configs/
|
||||
configs_dir = os.path.dirname(__file__)
|
||||
for file in os.listdir(configs_dir):
|
||||
path = os.path.join(configs_dir, file)
|
||||
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
||||
config_name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
module = importlib.import_module("TTS.vocoder.configs." + config_name)
|
||||
for attribute_name in dir(module):
|
||||
attribute = getattr(module, attribute_name)
|
||||
|
||||
if isclass(attribute):
|
||||
# Add the class to this package's variables
|
||||
globals()[attribute_name] = attribute
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,106 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullbandMelganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for FullBand MelGAN vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import FullbandMelganConfig
|
||||
>>> config = FullbandMelganConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `fullband_melgan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
'melgan_multiscale_discriminator`.
|
||||
discriminator_model_params (dict): The discriminator model parameters. Defaults to
|
||||
'{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}`
|
||||
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `melgan_generator`.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
use_stft_loss (bool):
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
|
||||
stft_loss_params (dict): STFT loss parameters. Default to
|
||||
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
"""
|
||||
|
||||
model: str = "fullband_melgan"
|
||||
|
||||
# Model specific params
|
||||
discriminator_model: str = "melgan_multiscale_discriminator"
|
||||
discriminator_model_params: dict = field(
|
||||
default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]}
|
||||
)
|
||||
generator_model: str = "melgan_generator"
|
||||
generator_model_params: dict = field(
|
||||
default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 4}
|
||||
)
|
||||
|
||||
# Training - overrides
|
||||
batch_size: int = 16
|
||||
seq_len: int = 8192
|
||||
pad_short: int = 2000
|
||||
use_noise_augment: bool = True
|
||||
use_cache: bool = True
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = False
|
||||
|
||||
stft_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240],
|
||||
}
|
||||
)
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0.5
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 2.5
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 108
|
||||
l1_spec_loss_weight: float = 0.0
|
||||
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class HifiganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for FullBand MelGAN vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import HifiganConfig
|
||||
>>> config = HifiganConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `hifigan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
'hifigan_discriminator`.
|
||||
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `hifigan_generator`.
|
||||
generator_model_params (dict): Parameters of the generator model. Defaults to
|
||||
`
|
||||
{
|
||||
"upsample_factors": [8, 8, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_type": "1",
|
||||
}
|
||||
`
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
use_stft_loss (bool):
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
|
||||
stft_loss_params (dict):
|
||||
STFT loss parameters. Default to
|
||||
`{
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240]
|
||||
}`
|
||||
l1_spec_loss_params (dict):
|
||||
L1 spectrogram loss parameters. Default to
|
||||
`{
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
"""
|
||||
|
||||
model: str = "hifigan"
|
||||
# model specific params
|
||||
discriminator_model: str = "hifigan_discriminator"
|
||||
generator_model: str = "hifigan_generator"
|
||||
generator_model_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"upsample_factors": [8, 8, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_type": "1",
|
||||
}
|
||||
)
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = False
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = True
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 1
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 108
|
||||
l1_spec_loss_weight: float = 45
|
||||
l1_spec_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}
|
||||
)
|
||||
|
||||
# optimizer parameters
|
||||
lr: float = 1e-4
|
||||
wd: float = 1e-6
|
||||
@@ -0,0 +1,106 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MelganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for MelGAN vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import MelganConfig
|
||||
>>> config = MelganConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `melgan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
'melgan_multiscale_discriminator`.
|
||||
discriminator_model_params (dict): The discriminator model parameters. Defaults to
|
||||
'{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}`
|
||||
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `melgan_generator`.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
use_stft_loss (bool):
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
|
||||
stft_loss_params (dict): STFT loss parameters. Default to
|
||||
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
"""
|
||||
|
||||
model: str = "melgan"
|
||||
|
||||
# Model specific params
|
||||
discriminator_model: str = "melgan_multiscale_discriminator"
|
||||
discriminator_model_params: dict = field(
|
||||
default_factory=lambda: {"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}
|
||||
)
|
||||
generator_model: str = "melgan_generator"
|
||||
generator_model_params: dict = field(
|
||||
default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 3}
|
||||
)
|
||||
|
||||
# Training - overrides
|
||||
batch_size: int = 16
|
||||
seq_len: int = 8192
|
||||
pad_short: int = 2000
|
||||
use_noise_augment: bool = True
|
||||
use_cache: bool = True
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = False
|
||||
|
||||
stft_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240],
|
||||
}
|
||||
)
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0.5
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 2.5
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 108
|
||||
l1_spec_loss_weight: float = 0
|
||||
@@ -0,0 +1,144 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultibandMelganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for MultiBandMelGAN vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import MultibandMelganConfig
|
||||
>>> config = MultibandMelganConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `multiband_melgan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
'melgan_multiscale_discriminator`.
|
||||
discriminator_model_params (dict): The discriminator model parameters. Defaults to
|
||||
'{
|
||||
"base_channels": 16,
|
||||
"max_channels": 512,
|
||||
"downsample_factors": [4, 4, 4]
|
||||
}`
|
||||
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `melgan_generator`.
|
||||
generator_model_param (dict):
|
||||
The generator model parameters. Defaults to `{"upsample_factors": [8, 4, 2], "num_res_blocks": 4}`.
|
||||
use_pqmf (bool):
|
||||
enable / disable PQMF modulation for multi-band training. Defaults to True.
|
||||
lr_gen (float):
|
||||
Initial learning rate for the generator model. Defaults to 0.0001.
|
||||
lr_disc (float):
|
||||
Initial learning rate for the discriminator model. Defaults to 0.0001.
|
||||
optimizer (torch.optim.Optimizer):
|
||||
Optimizer used for the training. Defaults to `AdamW`.
|
||||
optimizer_params (dict):
|
||||
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
|
||||
lr_scheduler_gen (torch.optim.Scheduler):
|
||||
Learning rate scheduler for the generator. Defaults to `MultiStepLR`.
|
||||
lr_scheduler_gen_params (dict):
|
||||
Parameters for the generator learning rate scheduler. Defaults to
|
||||
`{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`.
|
||||
lr_scheduler_disc (torch.optim.Scheduler):
|
||||
Learning rate scheduler for the discriminator. Defaults to `MultiStepLR`.
|
||||
lr_scheduler_dict_params (dict):
|
||||
Parameters for the discriminator learning rate scheduler. Defaults to
|
||||
`{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
steps_to_start_discriminator (int):
|
||||
Number of steps required to start training the discriminator. Defaults to 0.
|
||||
use_stft_loss (bool):`
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
|
||||
stft_loss_params (dict): STFT loss parameters. Default to
|
||||
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
"""
|
||||
|
||||
model: str = "multiband_melgan"
|
||||
|
||||
# Model specific params
|
||||
discriminator_model: str = "melgan_multiscale_discriminator"
|
||||
discriminator_model_params: dict = field(
|
||||
default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]}
|
||||
)
|
||||
generator_model: str = "multiband_melgan_generator"
|
||||
generator_model_params: dict = field(default_factory=lambda: {"upsample_factors": [8, 4, 2], "num_res_blocks": 4})
|
||||
use_pqmf: bool = True
|
||||
|
||||
# optimizer - overrides
|
||||
lr_gen: float = 0.0001 # Initial learning rate.
|
||||
lr_disc: float = 0.0001 # Initial learning rate.
|
||||
optimizer: str = "AdamW"
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
|
||||
lr_scheduler_gen: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_gen_params: dict = field(
|
||||
default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
|
||||
)
|
||||
lr_scheduler_disc: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_disc_params: dict = field(
|
||||
default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
|
||||
)
|
||||
|
||||
# Training - overrides
|
||||
batch_size: int = 64
|
||||
seq_len: int = 16384
|
||||
pad_short: int = 2000
|
||||
use_noise_augment: bool = False
|
||||
use_cache: bool = True
|
||||
steps_to_start_discriminator: bool = 200000
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = True
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = False
|
||||
|
||||
subband_stft_loss_params: dict = field(
|
||||
default_factory=lambda: {"n_ffts": [384, 683, 171], "hop_lengths": [30, 60, 10], "win_lengths": [150, 300, 60]}
|
||||
)
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0.5
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 2.5
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 108
|
||||
l1_spec_loss_weight: float = 0
|
||||
@@ -0,0 +1,134 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelWaveganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for ParallelWavegan vocoder.
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right configuration at initialization. Defaults to `gan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
'parallel_wavegan_discriminator`.
|
||||
discriminator_model_params (dict): The discriminator model kwargs. Defaults to
|
||||
'{"num_layers": 10}`
|
||||
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `parallel_wavegan_generator`.
|
||||
generator_model_param (dict):
|
||||
The generator model kwargs. Defaults to `{"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}`.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
steps_to_start_discriminator (int):
|
||||
Number of steps required to start training the discriminator. Defaults to 0.
|
||||
use_stft_loss (bool):`
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
|
||||
stft_loss_params (dict): STFT loss parameters. Default to
|
||||
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 0.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
lr_gen (float):
|
||||
Generator model initial learning rate. Defaults to 0.0002.
|
||||
lr_disc (float):
|
||||
Discriminator model initial learning rate. Defaults to 0.0002.
|
||||
optimizer (torch.optim.Optimizer):
|
||||
Optimizer used for the training. Defaults to `AdamW`.
|
||||
optimizer_params (dict):
|
||||
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
|
||||
lr_scheduler_gen (torch.optim.Scheduler):
|
||||
Learning rate scheduler for the generator. Defaults to `ExponentialLR`.
|
||||
lr_scheduler_gen_params (dict):
|
||||
Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`.
|
||||
lr_scheduler_disc (torch.optim.Scheduler):
|
||||
Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`.
|
||||
lr_scheduler_dict_params (dict):
|
||||
Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`.
|
||||
"""
|
||||
|
||||
model: str = "parallel_wavegan"
|
||||
|
||||
# Model specific params
|
||||
discriminator_model: str = "parallel_wavegan_discriminator"
|
||||
discriminator_model_params: dict = field(default_factory=lambda: {"num_layers": 10})
|
||||
generator_model: str = "parallel_wavegan_generator"
|
||||
generator_model_params: dict = field(
|
||||
default_factory=lambda: {"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}
|
||||
)
|
||||
|
||||
# Training - overrides
|
||||
batch_size: int = 6
|
||||
seq_len: int = 25600
|
||||
pad_short: int = 2000
|
||||
use_noise_augment: bool = False
|
||||
use_cache: bool = True
|
||||
steps_to_start_discriminator: int = 200000
|
||||
target_loss: str = "loss_1"
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = False
|
||||
|
||||
stft_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240],
|
||||
}
|
||||
)
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0.5
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 2.5
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 0
|
||||
l1_spec_loss_weight: float = 0
|
||||
|
||||
# optimizer overrides
|
||||
lr_gen: float = 0.0002 # Initial learning rate.
|
||||
lr_disc: float = 0.0002 # Initial learning rate.
|
||||
optimizer: str = "AdamW"
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
|
||||
lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
|
||||
lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_disc_params: dict = field(
|
||||
default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}
|
||||
)
|
||||
scheduler_after_epoch: bool = False
|
||||
@@ -0,0 +1,182 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.config import BaseAudioConfig, BaseTrainingConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseVocoderConfig(BaseTrainingConfig):
|
||||
"""Shared parameters among all the vocoder models.
|
||||
Args:
|
||||
audio (BaseAudioConfig):
|
||||
Audio processor config instance. Defaultsto `BaseAudioConfig()`.
|
||||
use_noise_augment (bool):
|
||||
Augment the input audio with random noise. Defaults to False/
|
||||
eval_split_size (int):
|
||||
Number of instances used for evaluation. Defaults to 10.
|
||||
data_path (str):
|
||||
Root path of the training data. All the audio files found recursively from this root path are used for
|
||||
training. Defaults to `""`.
|
||||
feature_path (str):
|
||||
Root path to the precomputed feature files. Defaults to None.
|
||||
seq_len (int):
|
||||
Length of the waveform segments used for training. Defaults to 1000.
|
||||
pad_short (int):
|
||||
Extra padding for the waveforms shorter than `seq_len`. Defaults to 0.
|
||||
conv_path (int):
|
||||
Extra padding for the feature frames against convolution of the edge frames. Defaults to MISSING.
|
||||
Defaults to 0.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. If the RAM is not enough, if may cause OOM.
|
||||
Defaults to False.
|
||||
epochs (int):
|
||||
Number of training epochs to. Defaults to 10000.
|
||||
wd (float):
|
||||
Weight decay.
|
||||
optimizer (torch.optim.Optimizer):
|
||||
Optimizer used for the training. Defaults to `AdamW`.
|
||||
optimizer_params (dict):
|
||||
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
|
||||
"""
|
||||
|
||||
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
|
||||
# dataloading
|
||||
use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms.
|
||||
eval_split_size: int = 10 # number of samples used for evaluation.
|
||||
# dataset
|
||||
data_path: str = "" # root data path. It finds all wav files recursively from there.
|
||||
feature_path: str = None # if you use precomputed features
|
||||
seq_len: int = 1000 # signal length used in training.
|
||||
pad_short: int = 0 # additional padding for short wavs
|
||||
conv_pad: int = 0 # additional padding against convolutions applied to spectrograms
|
||||
use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM.
|
||||
# OPTIMIZER
|
||||
epochs: int = 10000 # total number of epochs to train.
|
||||
wd: float = 0.0 # Weight decay weight.
|
||||
optimizer: str = "AdamW"
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseGANVocoderConfig(BaseVocoderConfig):
|
||||
"""Base config class used among all the GAN based vocoders.
|
||||
Args:
|
||||
use_stft_loss (bool):
|
||||
enable / disable the use of STFT loss. Defaults to True.
|
||||
use_subband_stft_loss (bool):
|
||||
enable / disable the use of Subband STFT loss. Defaults to True.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable the use of Mean Squared Error based GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable the use of Hinge GAN loss. Defaults to True.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable feature matching loss. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable L1 spectrogram loss. Defaults to True.
|
||||
stft_loss_weight (float):
|
||||
Loss weight that multiplies the computed loss value. Defaults to 0.
|
||||
subband_stft_loss_weight (float):
|
||||
Loss weight that multiplies the computed loss value. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
Loss weight that multiplies the computed loss value. Defaults to 1.
|
||||
hinge_G_loss_weight (float):
|
||||
Loss weight that multiplies the computed loss value. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Loss weight that multiplies the computed loss value. Defaults to 100.
|
||||
l1_spec_loss_weight (float):
|
||||
Loss weight that multiplies the computed loss value. Defaults to 45.
|
||||
stft_loss_params (dict):
|
||||
Parameters for the STFT loss. Defaults to `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`.
|
||||
l1_spec_loss_params (dict):
|
||||
Parameters for the L1 spectrogram loss. Defaults to
|
||||
`{
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}`
|
||||
target_loss (str):
|
||||
Target loss name that defines the quality of the model. Defaults to `G_avg_loss`.
|
||||
grad_clip (list):
|
||||
A list of gradient clipping theresholds for each optimizer. Any value less than 0 disables clipping.
|
||||
Defaults to [5, 5].
|
||||
lr_gen (float):
|
||||
Generator model initial learning rate. Defaults to 0.0002.
|
||||
lr_disc (float):
|
||||
Discriminator model initial learning rate. Defaults to 0.0002.
|
||||
lr_scheduler_gen (torch.optim.Scheduler):
|
||||
Learning rate scheduler for the generator. Defaults to `ExponentialLR`.
|
||||
lr_scheduler_gen_params (dict):
|
||||
Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
|
||||
lr_scheduler_disc (torch.optim.Scheduler):
|
||||
Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`.
|
||||
lr_scheduler_disc_params (dict):
|
||||
Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
|
||||
scheduler_after_epoch (bool):
|
||||
Whether to update the learning rate schedulers after each epoch. Defaults to True.
|
||||
use_pqmf (bool):
|
||||
enable / disable PQMF for subband approximation at training. Defaults to False.
|
||||
steps_to_start_discriminator (int):
|
||||
Number of steps required to start training the discriminator. Defaults to 0.
|
||||
diff_samples_for_G_and_D (bool):
|
||||
enable / disable use of different training samples for the generator and the discriminator iterations.
|
||||
Enabling it results in slower iterations but faster convergance in some cases. Defaults to False.
|
||||
"""
|
||||
|
||||
model: str = "gan"
|
||||
|
||||
# LOSS PARAMETERS
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = True
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = True
|
||||
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = True
|
||||
|
||||
# loss weights
|
||||
stft_loss_weight: float = 0
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 1
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 100
|
||||
l1_spec_loss_weight: float = 45
|
||||
|
||||
stft_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240],
|
||||
}
|
||||
)
|
||||
|
||||
l1_spec_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}
|
||||
)
|
||||
|
||||
target_loss: str = "loss_0" # loss value to pick the best model to save after each epoch
|
||||
|
||||
# optimizer
|
||||
grad_clip: float = field(default_factory=lambda: [5, 5])
|
||||
lr_gen: float = 0.0002 # Initial learning rate.
|
||||
lr_disc: float = 0.0002 # Initial learning rate.
|
||||
lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
|
||||
lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
|
||||
scheduler_after_epoch: bool = True
|
||||
|
||||
use_pqmf: bool = False # enable/disable using pqmf for multi-band training. (Multi-band MelGAN)
|
||||
steps_to_start_discriminator = 0 # start training the discriminator after this number of steps.
|
||||
diff_samples_for_G_and_D: bool = False # use different samples for G and D training steps.
|
||||
@@ -0,0 +1,161 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnivnetConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for UnivNet vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import UnivNetConfig
|
||||
>>> config = UnivNetConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `UnivNet`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
'UnivNet_discriminator`.
|
||||
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `UnivNet_generator`.
|
||||
generator_model_params (dict): Parameters of the generator model. Defaults to
|
||||
`
|
||||
{
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}
|
||||
`
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 32.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
use_stft_loss (bool):
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squeare Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by univnet model. Defaults to False.
|
||||
stft_loss_params (dict):
|
||||
STFT loss parameters. Default to
|
||||
`{
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240]
|
||||
}`
|
||||
l1_spec_loss_params (dict):
|
||||
L1 spectrogram loss parameters. Default to
|
||||
`{
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
"""
|
||||
|
||||
model: str = "univnet"
|
||||
batch_size: int = 32
|
||||
# model specific params
|
||||
discriminator_model: str = "univnet_discriminator"
|
||||
generator_model: str = "univnet_generator"
|
||||
generator_model_params: Dict = field(
|
||||
default_factory=lambda: {
|
||||
"in_channels": 64,
|
||||
"out_channels": 1,
|
||||
"hidden_channels": 32,
|
||||
"cond_channels": 80,
|
||||
"upsample_factors": [8, 8, 4],
|
||||
"lvc_layers_each_block": 4,
|
||||
"lvc_kernel_size": 3,
|
||||
"kpnet_hidden_channels": 64,
|
||||
"kpnet_conv_size": 3,
|
||||
"dropout": 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and univnet)
|
||||
use_l1_spec_loss: bool = False
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 2.5
|
||||
stft_loss_params: Dict = field(
|
||||
default_factory=lambda: {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240],
|
||||
}
|
||||
)
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 1
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 0
|
||||
l1_spec_loss_weight: float = 0
|
||||
l1_spec_loss_params: Dict = field(
|
||||
default_factory=lambda: {
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}
|
||||
)
|
||||
|
||||
# optimizer parameters
|
||||
lr_gen: float = 1e-4 # Initial learning rate.
|
||||
lr_disc: float = 1e-4 # Initial learning rate.
|
||||
lr_scheduler_gen: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
# lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
|
||||
lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
# lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
|
||||
optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
|
||||
steps_to_start_discriminator: int = 200000
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.generator_model_params["cond_channels"] = self.audio.num_mels
|
||||
@@ -0,0 +1,90 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
|
||||
from TTS.vocoder.models.wavegrad import WavegradArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavegradConfig(BaseVocoderConfig):
|
||||
"""Defines parameters for WaveGrad vocoder.
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import WavegradConfig
|
||||
>>> config = WavegradConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `wavegrad`.
|
||||
generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `wavegrad`.
|
||||
model_params (WavegradArgs): Model parameters. Check `WavegradArgs` for default values.
|
||||
target_loss (str):
|
||||
Target loss name that defines the quality of the model. Defaults to `avg_wavegrad_loss`.
|
||||
epochs (int):
|
||||
Number of epochs to traing the model. Defaults to 10000.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 96.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 6144.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
mixed_precision (bool):
|
||||
enable / disable mixed precision training. Default is True.
|
||||
eval_split_size (int):
|
||||
Number of samples used for evalutaion. Defaults to 50.
|
||||
train_noise_schedule (dict):
|
||||
Training noise schedule. Defaults to
|
||||
`{"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}`
|
||||
test_noise_schedule (dict):
|
||||
Inference noise schedule. For a better performance, you may need to use `bin/tune_wavegrad.py` to find a
|
||||
better schedule. Defaults to
|
||||
`
|
||||
{
|
||||
"min_val": 1e-6,
|
||||
"max_val": 1e-2,
|
||||
"num_steps": 50,
|
||||
}
|
||||
`
|
||||
grad_clip (float):
|
||||
Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 1.0
|
||||
lr (float):
|
||||
Initila leraning rate. Defaults to 1e-4.
|
||||
lr_scheduler (str):
|
||||
One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`.
|
||||
lr_scheduler_params (dict):
|
||||
kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`
|
||||
"""
|
||||
|
||||
model: str = "wavegrad"
|
||||
# Model specific params
|
||||
generator_model: str = "wavegrad"
|
||||
model_params: WavegradArgs = field(default_factory=WavegradArgs)
|
||||
target_loss: str = "loss" # loss value to pick the best model to save after each epoch
|
||||
|
||||
# Training - overrides
|
||||
epochs: int = 10000
|
||||
batch_size: int = 96
|
||||
seq_len: int = 6144
|
||||
use_cache: bool = True
|
||||
mixed_precision: bool = True
|
||||
eval_split_size: int = 50
|
||||
|
||||
# NOISE SCHEDULE PARAMS
|
||||
train_noise_schedule: dict = field(default_factory=lambda: {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000})
|
||||
|
||||
test_noise_schedule: dict = field(
|
||||
default_factory=lambda: { # inference noise schedule. Try TTS/bin/tune_wavegrad.py to find the optimal values.
|
||||
"min_val": 1e-6,
|
||||
"max_val": 1e-2,
|
||||
"num_steps": 50,
|
||||
}
|
||||
)
|
||||
|
||||
# optimizer overrides
|
||||
grad_clip: float = 1.0
|
||||
lr: float = 1e-4 # Initial learning rate.
|
||||
lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_params: dict = field(
|
||||
default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
|
||||
)
|
||||
@@ -0,0 +1,102 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
|
||||
from TTS.vocoder.models.wavernn import WavernnArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavernnConfig(BaseVocoderConfig):
|
||||
"""Defines parameters for Wavernn vocoder.
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import WavernnConfig
|
||||
>>> config = WavernnConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `wavernn`.
|
||||
mode (str):
|
||||
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
|
||||
Gaussian Distribution and `bits` for quantized bits as the model's output.
|
||||
mulaw (bool):
|
||||
enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
|
||||
to `True`.
|
||||
generator_model (str):
|
||||
One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `WaveRNN`.
|
||||
wavernn_model_params (dict):
|
||||
kwargs for the WaveRNN model. Defaults to
|
||||
`{
|
||||
"rnn_dims": 512,
|
||||
"fc_dims": 512,
|
||||
"compute_dims": 128,
|
||||
"res_out_dims": 128,
|
||||
"num_res_blocks": 10,
|
||||
"use_aux_net": True,
|
||||
"use_upsample_net": True,
|
||||
"upsample_factors": [4, 8, 8]
|
||||
}`
|
||||
batched (bool):
|
||||
enable / disable the batched inference. It speeds up the inference by splitting the input into segments and
|
||||
processing the segments in a batch. Then it merges the outputs with a certain overlap and smoothing. If
|
||||
you set it False, without CUDA, it is too slow to be practical. Defaults to True.
|
||||
target_samples (int):
|
||||
Size of the segments in batched mode. Defaults to 11000.
|
||||
overlap_sampels (int):
|
||||
Size of the overlap between consecutive segments. Defaults to 550.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 256.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 1280.
|
||||
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
mixed_precision (bool):
|
||||
enable / disable mixed precision training. Default is True.
|
||||
eval_split_size (int):
|
||||
Number of samples used for evalutaion. Defaults to 50.
|
||||
num_epochs_before_test (int):
|
||||
Number of epochs waited to run the next evalution. Since inference takes some time, it is better to
|
||||
wait some number of epochs not ot waste training time. Defaults to 10.
|
||||
grad_clip (float):
|
||||
Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 4.0
|
||||
lr (float):
|
||||
Initila leraning rate. Defaults to 1e-4.
|
||||
lr_scheduler (str):
|
||||
One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`.
|
||||
lr_scheduler_params (dict):
|
||||
kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [200000, 400000, 600000]}`
|
||||
"""
|
||||
|
||||
model: str = "wavernn"
|
||||
|
||||
# Model specific params
|
||||
model_args: WavernnArgs = field(default_factory=WavernnArgs)
|
||||
target_loss: str = "loss"
|
||||
|
||||
# Inference
|
||||
batched: bool = True
|
||||
target_samples: int = 11000
|
||||
overlap_samples: int = 550
|
||||
|
||||
# Training - overrides
|
||||
epochs: int = 10000
|
||||
batch_size: int = 256
|
||||
seq_len: int = 1280
|
||||
use_noise_augment: bool = False
|
||||
use_cache: bool = True
|
||||
mixed_precision: bool = True
|
||||
eval_split_size: int = 50
|
||||
num_epochs_before_test: int = (
|
||||
10 # number of epochs to wait until the next test run (synthesizing a full audio clip).
|
||||
)
|
||||
|
||||
# optimizer overrides
|
||||
grad_clip: float = 4.0
|
||||
lr: float = 1e-4 # Initial learning rate.
|
||||
lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
|
||||
lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.5, "milestones": [200000, 400000, 600000]})
|
||||
@@ -0,0 +1,58 @@
|
||||
from typing import List
|
||||
|
||||
from coqpit import Coqpit
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
|
||||
|
||||
def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
|
||||
if config.model.lower() in "gan":
|
||||
dataset = GANDataset(
|
||||
ap=ap,
|
||||
items=data_items,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
|
||||
is_training=not is_eval,
|
||||
return_segments=not is_eval,
|
||||
use_noise_augment=config.use_noise_augment,
|
||||
use_cache=config.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
dataset.shuffle_mapping()
|
||||
elif config.model.lower() == "wavegrad":
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=data_items,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
is_training=not is_eval,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=config.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
elif config.model.lower() == "wavernn":
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=data_items,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=config.model_params.pad,
|
||||
mode=config.model_params.mode,
|
||||
mulaw=config.model_params.mulaw,
|
||||
is_training=not is_eval,
|
||||
verbose=verbose,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
|
||||
return dataset
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,152 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from multiprocessing import Manager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class GANDataset(Dataset):
|
||||
"""
|
||||
GAN Dataset searchs for all the wav files under root path
|
||||
and converts them to acoustic features on the fly and returns
|
||||
random segments of (audio, feature) couples.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap,
|
||||
items,
|
||||
seq_len,
|
||||
hop_len,
|
||||
pad_short,
|
||||
conv_pad=2,
|
||||
return_pairs=False,
|
||||
is_training=True,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=False,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.ap = ap
|
||||
self.item_list = items
|
||||
self.compute_feat = not isinstance(items[0], (tuple, list))
|
||||
self.seq_len = seq_len
|
||||
self.hop_len = hop_len
|
||||
self.pad_short = pad_short
|
||||
self.conv_pad = conv_pad
|
||||
self.return_pairs = return_pairs
|
||||
self.is_training = is_training
|
||||
self.return_segments = return_segments
|
||||
self.use_cache = use_cache
|
||||
self.use_noise_augment = use_noise_augment
|
||||
self.verbose = verbose
|
||||
|
||||
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
|
||||
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
|
||||
|
||||
# map G and D instances
|
||||
self.G_to_D_mappings = list(range(len(self.item_list)))
|
||||
self.shuffle_mapping()
|
||||
|
||||
# cache acoustic features
|
||||
if use_cache:
|
||||
self.create_feature_cache()
|
||||
|
||||
def create_feature_cache(self):
|
||||
self.manager = Manager()
|
||||
self.cache = self.manager.list()
|
||||
self.cache += [None for _ in range(len(self.item_list))]
|
||||
|
||||
@staticmethod
|
||||
def find_wav_files(path):
|
||||
return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.item_list)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return different items for Generator and Discriminator and
|
||||
cache acoustic features"""
|
||||
|
||||
# set the seed differently for each worker
|
||||
if torch.utils.data.get_worker_info():
|
||||
random.seed(torch.utils.data.get_worker_info().seed)
|
||||
|
||||
if self.return_segments:
|
||||
item1 = self.load_item(idx)
|
||||
if self.return_pairs:
|
||||
idx2 = self.G_to_D_mappings[idx]
|
||||
item2 = self.load_item(idx2)
|
||||
return item1, item2
|
||||
return item1
|
||||
item1 = self.load_item(idx)
|
||||
return item1
|
||||
|
||||
def _pad_short_samples(self, audio, mel=None):
|
||||
"""Pad samples shorter than the output sequence length"""
|
||||
if len(audio) < self.seq_len:
|
||||
audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0)
|
||||
|
||||
if mel is not None and mel.shape[1] < self.feat_frame_len:
|
||||
pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]
|
||||
mel = np.pad(
|
||||
mel,
|
||||
([0, 0], [0, self.feat_frame_len - mel.shape[1]]),
|
||||
mode="constant",
|
||||
constant_values=pad_value.mean(),
|
||||
)
|
||||
return audio, mel
|
||||
|
||||
def shuffle_mapping(self):
|
||||
random.shuffle(self.G_to_D_mappings)
|
||||
|
||||
def load_item(self, idx):
|
||||
"""load (audio, feat) couple"""
|
||||
if self.compute_feat:
|
||||
# compute features from wav
|
||||
wavpath = self.item_list[idx]
|
||||
# print(wavpath)
|
||||
|
||||
if self.use_cache and self.cache[idx] is not None:
|
||||
audio, mel = self.cache[idx]
|
||||
else:
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
mel = self.ap.melspectrogram(audio)
|
||||
audio, mel = self._pad_short_samples(audio, mel)
|
||||
else:
|
||||
# load precomputed features
|
||||
wavpath, feat_path = self.item_list[idx]
|
||||
|
||||
if self.use_cache and self.cache[idx] is not None:
|
||||
audio, mel = self.cache[idx]
|
||||
else:
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
mel = np.load(feat_path)
|
||||
audio, mel = self._pad_short_samples(audio, mel)
|
||||
|
||||
# correct the audio length wrt padding applied in stft
|
||||
audio = np.pad(audio, (0, self.hop_len), mode="edge")
|
||||
audio = audio[: mel.shape[-1] * self.hop_len]
|
||||
assert (
|
||||
mel.shape[-1] * self.hop_len == audio.shape[-1]
|
||||
), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"
|
||||
|
||||
audio = torch.from_numpy(audio).float().unsqueeze(0)
|
||||
mel = torch.from_numpy(mel).float().squeeze(0)
|
||||
|
||||
if self.return_segments:
|
||||
max_mel_start = mel.shape[1] - self.feat_frame_len
|
||||
mel_start = random.randint(0, max_mel_start)
|
||||
mel_end = mel_start + self.feat_frame_len
|
||||
mel = mel[:, mel_start:mel_end]
|
||||
|
||||
audio_start = mel_start * self.hop_len
|
||||
audio = audio[:, audio_start : audio_start + self.seq_len]
|
||||
|
||||
if self.use_noise_augment and self.is_training and self.return_segments:
|
||||
audio = audio + (1 / 32768) * torch.randn_like(audio)
|
||||
return (mel, audio)
|
||||
@@ -0,0 +1,75 @@
|
||||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from coqpit import Coqpit
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
|
||||
|
||||
|
||||
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
|
||||
"""Process wav and compute mel and quantized wave signal.
|
||||
It is mainly used by WaveRNN dataloader.
|
||||
|
||||
Args:
|
||||
out_path (str): Parent folder path to save the files.
|
||||
config (Coqpit): Model config.
|
||||
ap (AudioProcessor): Audio processor.
|
||||
"""
|
||||
os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
|
||||
os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
|
||||
wav_files = find_wav_files(config.data_path)
|
||||
for path in tqdm(wav_files):
|
||||
wav_name = Path(path).stem
|
||||
quant_path = os.path.join(out_path, "quant", wav_name + ".npy")
|
||||
mel_path = os.path.join(out_path, "mel", wav_name + ".npy")
|
||||
y = ap.load_wav(path)
|
||||
mel = ap.melspectrogram(y)
|
||||
np.save(mel_path, mel)
|
||||
if isinstance(config.mode, int):
|
||||
quant = (
|
||||
mulaw_encode(wav=y, mulaw_qc=config.mode)
|
||||
if config.model_args.mulaw
|
||||
else quantize(x=y, quantize_bits=config.mode)
|
||||
)
|
||||
np.save(quant_path, quant)
|
||||
|
||||
|
||||
def find_wav_files(data_path, file_ext="wav"):
|
||||
wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True)
|
||||
return wav_paths
|
||||
|
||||
|
||||
def find_feat_files(data_path):
|
||||
feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True)
|
||||
return feat_paths
|
||||
|
||||
|
||||
def load_wav_data(data_path, eval_split_size, file_ext="wav"):
|
||||
wav_paths = find_wav_files(data_path, file_ext=file_ext)
|
||||
assert len(wav_paths) > 0, f" [!] {data_path} is empty."
|
||||
np.random.seed(0)
|
||||
np.random.shuffle(wav_paths)
|
||||
return wav_paths[:eval_split_size], wav_paths[eval_split_size:]
|
||||
|
||||
|
||||
def load_wav_feat_data(data_path, feat_path, eval_split_size):
|
||||
wav_paths = find_wav_files(data_path)
|
||||
feat_paths = find_feat_files(feat_path)
|
||||
|
||||
wav_paths.sort(key=lambda x: Path(x).stem)
|
||||
feat_paths.sort(key=lambda x: Path(x).stem)
|
||||
|
||||
assert len(wav_paths) == len(feat_paths), f" [!] {len(wav_paths)} vs {feat_paths}"
|
||||
for wav, feat in zip(wav_paths, feat_paths):
|
||||
wav_name = Path(wav).stem
|
||||
feat_name = Path(feat).stem
|
||||
assert wav_name == feat_name
|
||||
|
||||
items = list(zip(wav_paths, feat_paths))
|
||||
np.random.seed(0)
|
||||
np.random.shuffle(items)
|
||||
return items[:eval_split_size], items[eval_split_size:]
|
||||
@@ -0,0 +1,151 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from multiprocessing import Manager
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class WaveGradDataset(Dataset):
|
||||
"""
|
||||
WaveGrad Dataset searchs for all the wav files under root path
|
||||
and converts them to acoustic features on the fly and returns
|
||||
random segments of (audio, feature) couples.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap,
|
||||
items,
|
||||
seq_len,
|
||||
hop_len,
|
||||
pad_short,
|
||||
conv_pad=2,
|
||||
is_training=True,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=False,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.ap = ap
|
||||
self.item_list = items
|
||||
self.seq_len = seq_len if return_segments else None
|
||||
self.hop_len = hop_len
|
||||
self.pad_short = pad_short
|
||||
self.conv_pad = conv_pad
|
||||
self.is_training = is_training
|
||||
self.return_segments = return_segments
|
||||
self.use_cache = use_cache
|
||||
self.use_noise_augment = use_noise_augment
|
||||
self.verbose = verbose
|
||||
|
||||
if return_segments:
|
||||
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
|
||||
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
|
||||
|
||||
# cache acoustic features
|
||||
if use_cache:
|
||||
self.create_feature_cache()
|
||||
|
||||
def create_feature_cache(self):
|
||||
self.manager = Manager()
|
||||
self.cache = self.manager.list()
|
||||
self.cache += [None for _ in range(len(self.item_list))]
|
||||
|
||||
@staticmethod
|
||||
def find_wav_files(path):
|
||||
return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.item_list)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.load_item(idx)
|
||||
return item
|
||||
|
||||
def load_test_samples(self, num_samples: int) -> List[Tuple]:
|
||||
"""Return test samples.
|
||||
|
||||
Args:
|
||||
num_samples (int): Number of samples to return.
|
||||
|
||||
Returns:
|
||||
List[Tuple]: melspectorgram and audio.
|
||||
|
||||
Shapes:
|
||||
- melspectrogram (Tensor): :math:`[C, T]`
|
||||
- audio (Tensor): :math:`[T_audio]`
|
||||
"""
|
||||
samples = []
|
||||
return_segments = self.return_segments
|
||||
self.return_segments = False
|
||||
for idx in range(num_samples):
|
||||
mel, audio = self.load_item(idx)
|
||||
samples.append([mel, audio])
|
||||
self.return_segments = return_segments
|
||||
return samples
|
||||
|
||||
def load_item(self, idx):
|
||||
"""load (audio, feat) couple"""
|
||||
# compute features from wav
|
||||
wavpath = self.item_list[idx]
|
||||
|
||||
if self.use_cache and self.cache[idx] is not None:
|
||||
audio = self.cache[idx]
|
||||
else:
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
|
||||
if self.return_segments:
|
||||
# correct audio length wrt segment length
|
||||
if audio.shape[-1] < self.seq_len + self.pad_short:
|
||||
audio = np.pad(
|
||||
audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0
|
||||
)
|
||||
assert (
|
||||
audio.shape[-1] >= self.seq_len + self.pad_short
|
||||
), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
|
||||
|
||||
# correct the audio length wrt hop length
|
||||
p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
|
||||
audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0)
|
||||
|
||||
if self.use_cache:
|
||||
self.cache[idx] = audio
|
||||
|
||||
if self.return_segments:
|
||||
max_start = len(audio) - self.seq_len
|
||||
start = random.randint(0, max_start)
|
||||
end = start + self.seq_len
|
||||
audio = audio[start:end]
|
||||
|
||||
if self.use_noise_augment and self.is_training and self.return_segments:
|
||||
audio = audio + (1 / 32768) * torch.randn_like(audio)
|
||||
|
||||
mel = self.ap.melspectrogram(audio)
|
||||
mel = mel[..., :-1] # ignore the padding
|
||||
|
||||
audio = torch.from_numpy(audio).float()
|
||||
mel = torch.from_numpy(mel).float().squeeze(0)
|
||||
return (mel, audio)
|
||||
|
||||
@staticmethod
|
||||
def collate_full_clips(batch):
|
||||
"""This is used in tune_wavegrad.py.
|
||||
It pads sequences to the max length."""
|
||||
max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]
|
||||
max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0]
|
||||
|
||||
mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
|
||||
audios = torch.zeros([len(batch), max_audio_length])
|
||||
|
||||
for idx, b in enumerate(batch):
|
||||
mel = b[0]
|
||||
audio = b[1]
|
||||
mels[idx, :, : mel.shape[1]] = mel
|
||||
audios[idx, : audio.shape[0]] = audio
|
||||
|
||||
return mels, audios
|
||||
@@ -0,0 +1,118 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
|
||||
|
||||
|
||||
class WaveRNNDataset(Dataset):
|
||||
"""
|
||||
WaveRNN Dataset searchs for all the wav files under root path
|
||||
and converts them to acoustic features on the fly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
|
||||
):
|
||||
super().__init__()
|
||||
self.ap = ap
|
||||
self.compute_feat = not isinstance(items[0], (tuple, list))
|
||||
self.item_list = items
|
||||
self.seq_len = seq_len
|
||||
self.hop_len = hop_len
|
||||
self.mel_len = seq_len // hop_len
|
||||
self.pad = pad
|
||||
self.mode = mode
|
||||
self.mulaw = mulaw
|
||||
self.is_training = is_training
|
||||
self.verbose = verbose
|
||||
self.return_segments = return_segments
|
||||
|
||||
assert self.seq_len % self.hop_len == 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self.item_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.load_item(index)
|
||||
return item
|
||||
|
||||
def load_test_samples(self, num_samples):
|
||||
samples = []
|
||||
return_segments = self.return_segments
|
||||
self.return_segments = False
|
||||
for idx in range(num_samples):
|
||||
mel, audio, _ = self.load_item(idx)
|
||||
samples.append([mel, audio])
|
||||
self.return_segments = return_segments
|
||||
return samples
|
||||
|
||||
def load_item(self, index):
|
||||
"""
|
||||
load (audio, feat) couple if feature_path is set
|
||||
else compute it on the fly
|
||||
"""
|
||||
if self.compute_feat:
|
||||
wavpath = self.item_list[index]
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
if self.return_segments:
|
||||
min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
|
||||
else:
|
||||
min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
|
||||
if audio.shape[0] < min_audio_len:
|
||||
print(" [!] Instance is too short! : {}".format(wavpath))
|
||||
audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
|
||||
mel = self.ap.melspectrogram(audio)
|
||||
|
||||
if self.mode in ["gauss", "mold"]:
|
||||
x_input = audio
|
||||
elif isinstance(self.mode, int):
|
||||
x_input = (
|
||||
mulaw_encode(wav=audio, mulaw_qc=self.mode)
|
||||
if self.mulaw
|
||||
else quantize(x=audio, quantize_bits=self.mode)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
||||
|
||||
else:
|
||||
wavpath, feat_path = self.item_list[index]
|
||||
mel = np.load(feat_path.replace("/quant/", "/mel/"))
|
||||
|
||||
if mel.shape[-1] < self.mel_len + 2 * self.pad:
|
||||
print(" [!] Instance is too short! : {}".format(wavpath))
|
||||
self.item_list[index] = self.item_list[index + 1]
|
||||
feat_path = self.item_list[index]
|
||||
mel = np.load(feat_path.replace("/quant/", "/mel/"))
|
||||
if self.mode in ["gauss", "mold"]:
|
||||
x_input = self.ap.load_wav(wavpath)
|
||||
elif isinstance(self.mode, int):
|
||||
x_input = np.load(feat_path.replace("/mel/", "/quant/"))
|
||||
else:
|
||||
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
||||
|
||||
return mel, x_input, wavpath
|
||||
|
||||
def collate(self, batch):
|
||||
mel_win = self.seq_len // self.hop_len + 2 * self.pad
|
||||
max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
|
||||
|
||||
mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
|
||||
sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]
|
||||
|
||||
mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
|
||||
|
||||
coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)]
|
||||
|
||||
mels = np.stack(mels).astype(np.float32)
|
||||
if self.mode in ["gauss", "mold"]:
|
||||
coarse = np.stack(coarse).astype(np.float32)
|
||||
coarse = torch.FloatTensor(coarse)
|
||||
x_input = coarse[:, : self.seq_len]
|
||||
elif isinstance(self.mode, int):
|
||||
coarse = np.stack(coarse).astype(np.int64)
|
||||
coarse = torch.LongTensor(coarse)
|
||||
x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0
|
||||
y_coarse = coarse[:, 1:]
|
||||
mels = torch.FloatTensor(mels)
|
||||
return x_input, mels, y_coarse
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,56 @@
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
class ResStack(nn.Module):
|
||||
def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]):
|
||||
super().__init__()
|
||||
resstack = []
|
||||
for dilation in dilations:
|
||||
resstack += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(dilation),
|
||||
nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
|
||||
),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(padding),
|
||||
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
|
||||
]
|
||||
self.resstack = nn.Sequential(*resstack)
|
||||
|
||||
self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.shortcut(x)
|
||||
x2 = self.resstack(x)
|
||||
return x1 + x2
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_parametrizations(self.shortcut, "weight")
|
||||
remove_parametrizations(self.resstack[2], "weight")
|
||||
remove_parametrizations(self.resstack[5], "weight")
|
||||
remove_parametrizations(self.resstack[8], "weight")
|
||||
remove_parametrizations(self.resstack[11], "weight")
|
||||
remove_parametrizations(self.resstack[14], "weight")
|
||||
remove_parametrizations(self.resstack[17], "weight")
|
||||
|
||||
|
||||
class MRF(nn.Module):
|
||||
def __init__(self, kernels, channel, dilations=[1, 3, 5]): # # pylint: disable=dangerous-default-value
|
||||
super().__init__()
|
||||
self.resblock1 = ResStack(kernels[0], channel, 0, dilations)
|
||||
self.resblock2 = ResStack(kernels[1], channel, 6, dilations)
|
||||
self.resblock3 = ResStack(kernels[2], channel, 12, dilations)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.resblock1(x)
|
||||
x2 = self.resblock2(x)
|
||||
x3 = self.resblock3(x)
|
||||
return x1 + x2 + x3
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.resblock1.remove_weight_norm()
|
||||
self.resblock2.remove_weight_norm()
|
||||
self.resblock3.remove_weight_norm()
|
||||
@@ -0,0 +1,368 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
|
||||
|
||||
#################################
|
||||
# GENERATOR LOSSES
|
||||
#################################
|
||||
|
||||
|
||||
class STFTLoss(nn.Module):
|
||||
"""STFT loss. Input generate and real waveforms are converted
|
||||
to spectrograms compared with L1 and Spectral convergence losses.
|
||||
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
|
||||
|
||||
def __init__(self, n_fft, hop_length, win_length):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.stft = TorchSTFT(n_fft, hop_length, win_length)
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
y_hat_M = self.stft(y_hat)
|
||||
y_M = self.stft(y)
|
||||
# magnitude loss
|
||||
loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
|
||||
# spectral convergence loss
|
||||
loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
|
||||
return loss_mag, loss_sc
|
||||
|
||||
|
||||
class MultiScaleSTFTLoss(torch.nn.Module):
|
||||
"""Multi-scale STFT loss. Input generate and real waveforms are converted
|
||||
to spectrograms compared with L1 and Spectral convergence losses.
|
||||
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
|
||||
|
||||
def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
|
||||
super().__init__()
|
||||
self.loss_funcs = torch.nn.ModuleList()
|
||||
for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
|
||||
self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
N = len(self.loss_funcs)
|
||||
loss_sc = 0
|
||||
loss_mag = 0
|
||||
for f in self.loss_funcs:
|
||||
lm, lsc = f(y_hat, y)
|
||||
loss_mag += lm
|
||||
loss_sc += lsc
|
||||
loss_sc /= N
|
||||
loss_mag /= N
|
||||
return loss_mag, loss_sc
|
||||
|
||||
|
||||
class L1SpecLoss(nn.Module):
|
||||
"""L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""
|
||||
|
||||
def __init__(
|
||||
self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
|
||||
):
|
||||
super().__init__()
|
||||
self.use_mel = use_mel
|
||||
self.stft = TorchSTFT(
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
sample_rate=sample_rate,
|
||||
mel_fmin=mel_fmin,
|
||||
mel_fmax=mel_fmax,
|
||||
n_mels=n_mels,
|
||||
use_mel=use_mel,
|
||||
)
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
y_hat_M = self.stft(y_hat)
|
||||
y_M = self.stft(y)
|
||||
# magnitude loss
|
||||
loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
|
||||
return loss_mag
|
||||
|
||||
|
||||
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
||||
"""Multiscale STFT loss for multi band model outputs.
|
||||
From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, y_hat, y):
|
||||
y_hat = y_hat.view(-1, 1, y_hat.shape[2])
|
||||
y = y.view(-1, 1, y.shape[2])
|
||||
return super().forward(y_hat.squeeze(1), y.squeeze(1))
|
||||
|
||||
|
||||
class MSEGLoss(nn.Module):
|
||||
"""Mean Squared Generator Loss"""
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_real):
|
||||
loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
|
||||
return loss_fake
|
||||
|
||||
|
||||
class HingeGLoss(nn.Module):
|
||||
"""Hinge Discriminator Loss"""
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_real):
|
||||
# TODO: this might be wrong
|
||||
loss_fake = torch.mean(F.relu(1.0 - score_real))
|
||||
return loss_fake
|
||||
|
||||
|
||||
##################################
|
||||
# DISCRIMINATOR LOSSES
|
||||
##################################
|
||||
|
||||
|
||||
class MSEDLoss(nn.Module):
|
||||
"""Mean Squared Discriminator Loss"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super().__init__()
|
||||
self.loss_func = nn.MSELoss()
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
|
||||
loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
|
||||
loss_d = loss_real + loss_fake
|
||||
return loss_d, loss_real, loss_fake
|
||||
|
||||
|
||||
class HingeDLoss(nn.Module):
|
||||
"""Hinge Discriminator Loss"""
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(F.relu(1.0 - score_real))
|
||||
loss_fake = torch.mean(F.relu(1.0 + score_fake))
|
||||
loss_d = loss_real + loss_fake
|
||||
return loss_d, loss_real, loss_fake
|
||||
|
||||
|
||||
class MelganFeatureLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super().__init__()
|
||||
self.loss_func = nn.L1Loss()
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, fake_feats, real_feats):
|
||||
loss_feats = 0
|
||||
num_feats = 0
|
||||
for idx, _ in enumerate(fake_feats):
|
||||
for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
|
||||
loss_feats += self.loss_func(fake_feat, real_feat)
|
||||
num_feats += 1
|
||||
loss_feats = loss_feats / num_feats
|
||||
return loss_feats
|
||||
|
||||
|
||||
#####################################
|
||||
# LOSS WRAPPERS
|
||||
#####################################
|
||||
|
||||
|
||||
def _apply_G_adv_loss(scores_fake, loss_func):
|
||||
"""Compute G adversarial loss function
|
||||
and normalize values"""
|
||||
adv_loss = 0
|
||||
if isinstance(scores_fake, list):
|
||||
for score_fake in scores_fake:
|
||||
fake_loss = loss_func(score_fake)
|
||||
adv_loss += fake_loss
|
||||
adv_loss /= len(scores_fake)
|
||||
else:
|
||||
fake_loss = loss_func(scores_fake)
|
||||
adv_loss = fake_loss
|
||||
return adv_loss
|
||||
|
||||
|
||||
def _apply_D_loss(scores_fake, scores_real, loss_func):
|
||||
"""Compute D loss func and normalize loss values"""
|
||||
loss = 0
|
||||
real_loss = 0
|
||||
fake_loss = 0
|
||||
if isinstance(scores_fake, list):
|
||||
# multi-scale loss
|
||||
for score_fake, score_real in zip(scores_fake, scores_real):
|
||||
total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real)
|
||||
loss += total_loss
|
||||
real_loss += real_loss_
|
||||
fake_loss += fake_loss_
|
||||
# normalize loss values with number of scales (discriminators)
|
||||
loss /= len(scores_fake)
|
||||
real_loss /= len(scores_real)
|
||||
fake_loss /= len(scores_fake)
|
||||
else:
|
||||
# single scale loss
|
||||
total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
|
||||
loss = total_loss
|
||||
return loss, real_loss, fake_loss
|
||||
|
||||
|
||||
##################################
|
||||
# MODEL LOSSES
|
||||
##################################
|
||||
|
||||
|
||||
class GeneratorLoss(nn.Module):
|
||||
"""Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes
|
||||
losses. It allows to experiment with different combinations of loss functions with different models by just
|
||||
changing configurations.
|
||||
|
||||
Args:
|
||||
C (AttrDict): model configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, C):
|
||||
super().__init__()
|
||||
assert not (
|
||||
C.use_mse_gan_loss and C.use_hinge_gan_loss
|
||||
), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
|
||||
|
||||
self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
|
||||
self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
|
||||
self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False
|
||||
self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False
|
||||
self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False
|
||||
self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False
|
||||
|
||||
self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0
|
||||
self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0
|
||||
self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0
|
||||
self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0
|
||||
self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0
|
||||
self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0
|
||||
|
||||
if C.use_stft_loss:
|
||||
self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
|
||||
if C.use_subband_stft_loss:
|
||||
self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
|
||||
if C.use_mse_gan_loss:
|
||||
self.mse_loss = MSEGLoss()
|
||||
if C.use_hinge_gan_loss:
|
||||
self.hinge_loss = HingeGLoss()
|
||||
if C.use_feat_match_loss:
|
||||
self.feat_match_loss = MelganFeatureLoss()
|
||||
if C.use_l1_spec_loss:
|
||||
assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
|
||||
self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)
|
||||
|
||||
def forward(
|
||||
self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
|
||||
):
|
||||
gen_loss = 0
|
||||
adv_loss = 0
|
||||
return_dict = {}
|
||||
|
||||
# STFT Loss
|
||||
if self.use_stft_loss:
|
||||
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
|
||||
return_dict["G_stft_loss_mg"] = stft_loss_mg
|
||||
return_dict["G_stft_loss_sc"] = stft_loss_sc
|
||||
gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
|
||||
|
||||
# L1 Spec loss
|
||||
if self.use_l1_spec_loss:
|
||||
l1_spec_loss = self.l1_spec_loss(y_hat, y)
|
||||
return_dict["G_l1_spec_loss"] = l1_spec_loss
|
||||
gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss
|
||||
|
||||
# subband STFT Loss
|
||||
if self.use_subband_stft_loss:
|
||||
subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
|
||||
return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
|
||||
return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
|
||||
gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
|
||||
|
||||
# multiscale MSE adversarial loss
|
||||
if self.use_mse_gan_loss and scores_fake is not None:
|
||||
mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
|
||||
return_dict["G_mse_fake_loss"] = mse_fake_loss
|
||||
adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss
|
||||
|
||||
# multiscale Hinge adversarial loss
|
||||
if self.use_hinge_gan_loss and not scores_fake is not None:
|
||||
hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
|
||||
return_dict["G_hinge_fake_loss"] = hinge_fake_loss
|
||||
adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss
|
||||
|
||||
# Feature Matching Loss
|
||||
if self.use_feat_match_loss and not feats_fake is None:
|
||||
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
|
||||
return_dict["G_feat_match_loss"] = feat_match_loss
|
||||
adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
|
||||
return_dict["loss"] = gen_loss + adv_loss
|
||||
return_dict["G_gen_loss"] = gen_loss
|
||||
return_dict["G_adv_loss"] = adv_loss
|
||||
return return_dict
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
"""Like ```GeneratorLoss```"""
|
||||
|
||||
def __init__(self, C):
|
||||
super().__init__()
|
||||
assert not (
|
||||
C.use_mse_gan_loss and C.use_hinge_gan_loss
|
||||
), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
|
||||
|
||||
self.use_mse_gan_loss = C.use_mse_gan_loss
|
||||
self.use_hinge_gan_loss = C.use_hinge_gan_loss
|
||||
|
||||
if C.use_mse_gan_loss:
|
||||
self.mse_loss = MSEDLoss()
|
||||
if C.use_hinge_gan_loss:
|
||||
self.hinge_loss = HingeDLoss()
|
||||
|
||||
def forward(self, scores_fake, scores_real):
|
||||
loss = 0
|
||||
return_dict = {}
|
||||
|
||||
if self.use_mse_gan_loss:
|
||||
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
|
||||
scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
|
||||
)
|
||||
return_dict["D_mse_gan_loss"] = mse_D_loss
|
||||
return_dict["D_mse_gan_real_loss"] = mse_D_real_loss
|
||||
return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss
|
||||
loss += mse_D_loss
|
||||
|
||||
if self.use_hinge_gan_loss:
|
||||
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(
|
||||
scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
|
||||
)
|
||||
return_dict["D_hinge_gan_loss"] = hinge_D_loss
|
||||
return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss
|
||||
return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
|
||||
loss += hinge_D_loss
|
||||
|
||||
return_dict["loss"] = loss
|
||||
return return_dict
|
||||
|
||||
|
||||
class WaveRNNLoss(nn.Module):
|
||||
def __init__(self, wave_rnn_mode: Union[str, int]):
|
||||
super().__init__()
|
||||
if wave_rnn_mode == "mold":
|
||||
self.loss_func = discretized_mix_logistic_loss
|
||||
elif wave_rnn_mode == "gauss":
|
||||
self.loss_func = gaussian_loss
|
||||
elif isinstance(wave_rnn_mode, int):
|
||||
self.loss_func = torch.nn.CrossEntropyLoss()
|
||||
else:
|
||||
raise ValueError(" [!] Unknown mode for Wavernn.")
|
||||
|
||||
def forward(self, y_hat, y) -> Dict:
|
||||
loss = self.loss_func(y_hat, y)
|
||||
return {"loss": loss}
|
||||
@@ -0,0 +1,198 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class KernelPredictor(torch.nn.Module):
|
||||
"""Kernel predictor for the location-variable convolutions"""
|
||||
|
||||
def __init__( # pylint: disable=dangerous-default-value
|
||||
self,
|
||||
cond_channels,
|
||||
conv_in_channels,
|
||||
conv_out_channels,
|
||||
conv_layers,
|
||||
conv_kernel_size=3,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0,
|
||||
kpnet_nonlinear_activation="LeakyReLU",
|
||||
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cond_channels (int): number of channel for the conditioning sequence,
|
||||
conv_in_channels (int): number of channel for the input sequence,
|
||||
conv_out_channels (int): number of channel for the output sequence,
|
||||
conv_layers (int):
|
||||
kpnet_
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.conv_in_channels = conv_in_channels
|
||||
self.conv_out_channels = conv_out_channels
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.conv_layers = conv_layers
|
||||
|
||||
l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
|
||||
l_b = conv_out_channels * conv_layers
|
||||
|
||||
padding = (kpnet_conv_size - 1) // 2
|
||||
self.input_conv = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.residual_conv = torch.nn.Sequential(
|
||||
torch.nn.Dropout(kpnet_dropout),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Dropout(kpnet_dropout),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Dropout(kpnet_dropout),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True)
|
||||
self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True)
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
Args:
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
Returns:
|
||||
"""
|
||||
batch, _, cond_length = c.shape
|
||||
|
||||
c = self.input_conv(c)
|
||||
c = c + self.residual_conv(c)
|
||||
k = self.kernel_conv(c)
|
||||
b = self.bias_conv(c)
|
||||
|
||||
kernels = k.contiguous().view(
|
||||
batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length
|
||||
)
|
||||
bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length)
|
||||
return kernels, bias
|
||||
|
||||
|
||||
class LVCBlock(torch.nn.Module):
|
||||
"""the location-variable convolutions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
cond_channels,
|
||||
upsample_ratio,
|
||||
conv_layers=4,
|
||||
conv_kernel_size=3,
|
||||
cond_hop_length=256,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_hop_length = cond_hop_length
|
||||
self.conv_layers = conv_layers
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.convs = torch.nn.ModuleList()
|
||||
|
||||
self.upsample = torch.nn.ConvTranspose1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=upsample_ratio * 2,
|
||||
stride=upsample_ratio,
|
||||
padding=upsample_ratio // 2 + upsample_ratio % 2,
|
||||
output_padding=upsample_ratio % 2,
|
||||
)
|
||||
|
||||
self.kernel_predictor = KernelPredictor(
|
||||
cond_channels=cond_channels,
|
||||
conv_in_channels=in_channels,
|
||||
conv_out_channels=2 * in_channels,
|
||||
conv_layers=conv_layers,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
kpnet_dropout=kpnet_dropout,
|
||||
)
|
||||
|
||||
for i in range(conv_layers):
|
||||
padding = (3**i) * int((conv_kernel_size - 1) / 2)
|
||||
conv = torch.nn.Conv1d(
|
||||
in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i
|
||||
)
|
||||
|
||||
self.convs.append(conv)
|
||||
|
||||
def forward(self, x, c):
|
||||
"""forward propagation of the location-variable convolutions.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length)
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
|
||||
Returns:
|
||||
Tensor: the output sequence (batch, in_channels, in_length)
|
||||
"""
|
||||
in_channels = x.shape[1]
|
||||
kernels, bias = self.kernel_predictor(c)
|
||||
|
||||
x = F.leaky_relu(x, 0.2)
|
||||
x = self.upsample(x)
|
||||
|
||||
for i in range(self.conv_layers):
|
||||
y = F.leaky_relu(x, 0.2)
|
||||
y = self.convs[i](y)
|
||||
y = F.leaky_relu(y, 0.2)
|
||||
|
||||
k = kernels[:, i, :, :, :, :]
|
||||
b = bias[:, i, :, :]
|
||||
y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
|
||||
x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def location_variable_convolution(x, kernel, bias, dilation, hop_size):
|
||||
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
||||
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length).
|
||||
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
||||
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
||||
dilation (int): the dilation of convolution.
|
||||
hop_size (int): the hop_size of the conditioning sequence.
|
||||
Returns:
|
||||
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
||||
"""
|
||||
batch, _, in_length = x.shape
|
||||
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
||||
|
||||
assert in_length == (
|
||||
kernel_length * hop_size
|
||||
), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"
|
||||
|
||||
padding = dilation * int((kernel_size - 1) / 2)
|
||||
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
||||
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
|
||||
if hop_size < dilation:
|
||||
x = F.pad(x, (0, dilation), "constant", 0)
|
||||
x = x.unfold(
|
||||
3, dilation, dilation
|
||||
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||
x = x[:, :, :, :, :hop_size]
|
||||
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
|
||||
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
||||
o = o + bias.unsqueeze(-1).unsqueeze(-1)
|
||||
o = o.contiguous().view(batch, out_channels, -1)
|
||||
return o
|
||||
@@ -0,0 +1,43 @@
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
|
||||
class ResidualStack(nn.Module):
|
||||
def __init__(self, channels, num_res_blocks, kernel_size):
|
||||
super().__init__()
|
||||
|
||||
assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
|
||||
base_padding = (kernel_size - 1) // 2
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for idx in range(num_res_blocks):
|
||||
layer_kernel_size = kernel_size
|
||||
layer_dilation = layer_kernel_size**idx
|
||||
layer_padding = base_padding * layer_dilation
|
||||
self.blocks += [
|
||||
nn.Sequential(
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(layer_padding),
|
||||
weight_norm(
|
||||
nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
|
||||
),
|
||||
nn.LeakyReLU(0.2),
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
|
||||
)
|
||||
]
|
||||
|
||||
self.shortcuts = nn.ModuleList(
|
||||
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||
x = shortcut(x) + block(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||
remove_parametrizations(block[2], "weight")
|
||||
remove_parametrizations(block[4], "weight")
|
||||
remove_parametrizations(shortcut, "weight")
|
||||
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class ResidualBlock(torch.nn.Module):
|
||||
"""Residual block module in WaveNet."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size=3,
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
aux_channels=80,
|
||||
dropout=0.0,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
use_causal_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
# no future time stamps available
|
||||
if use_causal_conv:
|
||||
padding = (kernel_size - 1) * dilation
|
||||
else:
|
||||
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
||||
padding = (kernel_size - 1) // 2 * dilation
|
||||
self.use_causal_conv = use_causal_conv
|
||||
|
||||
# dilation conv
|
||||
self.conv = torch.nn.Conv1d(
|
||||
res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias
|
||||
)
|
||||
|
||||
# local conditioning
|
||||
if aux_channels > 0:
|
||||
self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
|
||||
else:
|
||||
self.conv1x1_aux = None
|
||||
|
||||
# conv output is split into two groups
|
||||
gate_out_channels = gate_channels // 2
|
||||
self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
|
||||
self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)
|
||||
|
||||
def forward(self, x, c):
|
||||
"""
|
||||
x: B x D_res x T
|
||||
c: B x D_aux x T
|
||||
"""
|
||||
residual = x
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = self.conv(x)
|
||||
|
||||
# remove future time steps if use_causal_conv conv
|
||||
x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x
|
||||
|
||||
# split into two part for gated activation
|
||||
splitdim = 1
|
||||
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
|
||||
|
||||
# local conditioning
|
||||
if c is not None:
|
||||
assert self.conv1x1_aux is not None
|
||||
c = self.conv1x1_aux(c)
|
||||
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
|
||||
xa, xb = xa + ca, xb + cb
|
||||
|
||||
x = torch.tanh(xa) * torch.sigmoid(xb)
|
||||
|
||||
# for skip connection
|
||||
s = self.conv1x1_skip(x)
|
||||
|
||||
# for residual connection
|
||||
x = (self.conv1x1_out(x) + residual) * (0.5**2)
|
||||
|
||||
return x, s
|
||||
@@ -0,0 +1,53 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy import signal as sig
|
||||
|
||||
|
||||
# adapted from
|
||||
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
|
||||
class PQMF(torch.nn.Module):
|
||||
def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
|
||||
super().__init__()
|
||||
|
||||
self.N = N
|
||||
self.taps = taps
|
||||
self.cutoff = cutoff
|
||||
self.beta = beta
|
||||
|
||||
QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta))
|
||||
H = np.zeros((N, len(QMF)))
|
||||
G = np.zeros((N, len(QMF)))
|
||||
for k in range(N):
|
||||
constant_factor = (
|
||||
(2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2))
|
||||
) # TODO: (taps - 1) -> taps
|
||||
phase = (-1) ** k * np.pi / 4
|
||||
H[k] = 2 * QMF * np.cos(constant_factor + phase)
|
||||
|
||||
G[k] = 2 * QMF * np.cos(constant_factor - phase)
|
||||
|
||||
H = torch.from_numpy(H[:, None, :]).float()
|
||||
G = torch.from_numpy(G[None, :, :]).float()
|
||||
|
||||
self.register_buffer("H", H)
|
||||
self.register_buffer("G", G)
|
||||
|
||||
updown_filter = torch.zeros((N, N, N)).float()
|
||||
for k in range(N):
|
||||
updown_filter[k, k, 0] = 1.0
|
||||
self.register_buffer("updown_filter", updown_filter)
|
||||
self.N = N
|
||||
|
||||
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.analysis(x)
|
||||
|
||||
def analysis(self, x):
|
||||
return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)
|
||||
|
||||
def synthesis(self, x):
|
||||
x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N)
|
||||
x = F.conv1d(x, self.G, padding=self.taps // 2)
|
||||
return x
|
||||
@@ -0,0 +1,640 @@
|
||||
0.0000000e+000
|
||||
-5.5252865e-004
|
||||
-5.6176926e-004
|
||||
-4.9475181e-004
|
||||
-4.8752280e-004
|
||||
-4.8937912e-004
|
||||
-5.0407143e-004
|
||||
-5.2265643e-004
|
||||
-5.4665656e-004
|
||||
-5.6778026e-004
|
||||
-5.8709305e-004
|
||||
-6.1327474e-004
|
||||
-6.3124935e-004
|
||||
-6.5403334e-004
|
||||
-6.7776908e-004
|
||||
-6.9416146e-004
|
||||
-7.1577365e-004
|
||||
-7.2550431e-004
|
||||
-7.4409419e-004
|
||||
-7.4905981e-004
|
||||
-7.6813719e-004
|
||||
-7.7248486e-004
|
||||
-7.8343323e-004
|
||||
-7.7798695e-004
|
||||
-7.8036647e-004
|
||||
-7.8014496e-004
|
||||
-7.7579773e-004
|
||||
-7.6307936e-004
|
||||
-7.5300014e-004
|
||||
-7.3193572e-004
|
||||
-7.2153920e-004
|
||||
-6.9179375e-004
|
||||
-6.6504151e-004
|
||||
-6.3415949e-004
|
||||
-5.9461189e-004
|
||||
-5.5645764e-004
|
||||
-5.1455722e-004
|
||||
-4.6063255e-004
|
||||
-4.0951215e-004
|
||||
-3.5011759e-004
|
||||
-2.8969812e-004
|
||||
-2.0983373e-004
|
||||
-1.4463809e-004
|
||||
-6.1733441e-005
|
||||
1.3494974e-005
|
||||
1.0943831e-004
|
||||
2.0430171e-004
|
||||
2.9495311e-004
|
||||
4.0265402e-004
|
||||
5.1073885e-004
|
||||
6.2393761e-004
|
||||
7.4580259e-004
|
||||
8.6084433e-004
|
||||
9.8859883e-004
|
||||
1.1250155e-003
|
||||
1.2577885e-003
|
||||
1.3902495e-003
|
||||
1.5443220e-003
|
||||
1.6868083e-003
|
||||
1.8348265e-003
|
||||
1.9841141e-003
|
||||
2.1461584e-003
|
||||
2.3017255e-003
|
||||
2.4625617e-003
|
||||
2.6201759e-003
|
||||
2.7870464e-003
|
||||
2.9469448e-003
|
||||
3.1125421e-003
|
||||
3.2739613e-003
|
||||
3.4418874e-003
|
||||
3.6008268e-003
|
||||
3.7603923e-003
|
||||
3.9207432e-003
|
||||
4.0819753e-003
|
||||
4.2264269e-003
|
||||
4.3730720e-003
|
||||
4.5209853e-003
|
||||
4.6606461e-003
|
||||
4.7932561e-003
|
||||
4.9137604e-003
|
||||
5.0393023e-003
|
||||
5.1407354e-003
|
||||
5.2461166e-003
|
||||
5.3471681e-003
|
||||
5.4196776e-003
|
||||
5.4876040e-003
|
||||
5.5475715e-003
|
||||
5.5938023e-003
|
||||
5.6220643e-003
|
||||
5.6455197e-003
|
||||
5.6389200e-003
|
||||
5.6266114e-003
|
||||
5.5917129e-003
|
||||
5.5404364e-003
|
||||
5.4753783e-003
|
||||
5.3838976e-003
|
||||
5.2715759e-003
|
||||
5.1382275e-003
|
||||
4.9839688e-003
|
||||
4.8109469e-003
|
||||
4.6039530e-003
|
||||
4.3801862e-003
|
||||
4.1251642e-003
|
||||
3.8456408e-003
|
||||
3.5401247e-003
|
||||
3.2091886e-003
|
||||
2.8446758e-003
|
||||
2.4508540e-003
|
||||
2.0274176e-003
|
||||
1.5784683e-003
|
||||
1.0902329e-003
|
||||
5.8322642e-004
|
||||
2.7604519e-005
|
||||
-5.4642809e-004
|
||||
-1.1568136e-003
|
||||
-1.8039473e-003
|
||||
-2.4826724e-003
|
||||
-3.1933778e-003
|
||||
-3.9401124e-003
|
||||
-4.7222596e-003
|
||||
-5.5337211e-003
|
||||
-6.3792293e-003
|
||||
-7.2615817e-003
|
||||
-8.1798233e-003
|
||||
-9.1325330e-003
|
||||
-1.0115022e-002
|
||||
-1.1131555e-002
|
||||
-1.2185000e-002
|
||||
-1.3271822e-002
|
||||
-1.4390467e-002
|
||||
-1.5540555e-002
|
||||
-1.6732471e-002
|
||||
-1.7943338e-002
|
||||
-1.9187243e-002
|
||||
-2.0453179e-002
|
||||
-2.1746755e-002
|
||||
-2.3068017e-002
|
||||
-2.4416099e-002
|
||||
-2.5787585e-002
|
||||
-2.7185943e-002
|
||||
-2.8607217e-002
|
||||
-3.0050266e-002
|
||||
-3.1501761e-002
|
||||
-3.2975408e-002
|
||||
-3.4462095e-002
|
||||
-3.5969756e-002
|
||||
-3.7481285e-002
|
||||
-3.9005368e-002
|
||||
-4.0534917e-002
|
||||
-4.2064909e-002
|
||||
-4.3609754e-002
|
||||
-4.5148841e-002
|
||||
-4.6684303e-002
|
||||
-4.8216572e-002
|
||||
-4.9738576e-002
|
||||
-5.1255616e-002
|
||||
-5.2763075e-002
|
||||
-5.4245277e-002
|
||||
-5.5717365e-002
|
||||
-5.7161645e-002
|
||||
-5.8591568e-002
|
||||
-5.9983748e-002
|
||||
-6.1345517e-002
|
||||
-6.2685781e-002
|
||||
-6.3971590e-002
|
||||
-6.5224711e-002
|
||||
-6.6436751e-002
|
||||
-6.7607599e-002
|
||||
-6.8704383e-002
|
||||
-6.9763024e-002
|
||||
-7.0762871e-002
|
||||
-7.1700267e-002
|
||||
-7.2568258e-002
|
||||
-7.3362026e-002
|
||||
-7.4100364e-002
|
||||
-7.4745256e-002
|
||||
-7.5313734e-002
|
||||
-7.5800836e-002
|
||||
-7.6199248e-002
|
||||
-7.6499217e-002
|
||||
-7.6709349e-002
|
||||
-7.6817398e-002
|
||||
-7.6823001e-002
|
||||
-7.6720492e-002
|
||||
-7.6505072e-002
|
||||
-7.6174832e-002
|
||||
-7.5730576e-002
|
||||
-7.5157626e-002
|
||||
-7.4466439e-002
|
||||
-7.3640601e-002
|
||||
-7.2677464e-002
|
||||
-7.1582636e-002
|
||||
-7.0353307e-002
|
||||
-6.8966401e-002
|
||||
-6.7452502e-002
|
||||
-6.5769067e-002
|
||||
-6.3944481e-002
|
||||
-6.1960278e-002
|
||||
-5.9816657e-002
|
||||
-5.7515269e-002
|
||||
-5.5046003e-002
|
||||
-5.2409382e-002
|
||||
-4.9597868e-002
|
||||
-4.6630331e-002
|
||||
-4.3476878e-002
|
||||
-4.0145828e-002
|
||||
-3.6641812e-002
|
||||
-3.2958393e-002
|
||||
-2.9082401e-002
|
||||
-2.5030756e-002
|
||||
-2.0799707e-002
|
||||
-1.6370126e-002
|
||||
-1.1762383e-002
|
||||
-6.9636862e-003
|
||||
-1.9765601e-003
|
||||
3.2086897e-003
|
||||
8.5711749e-003
|
||||
1.4128883e-002
|
||||
1.9883413e-002
|
||||
2.5822729e-002
|
||||
3.1953127e-002
|
||||
3.8277657e-002
|
||||
4.4780682e-002
|
||||
5.1480418e-002
|
||||
5.8370533e-002
|
||||
6.5440985e-002
|
||||
7.2694330e-002
|
||||
8.0137293e-002
|
||||
8.7754754e-002
|
||||
9.5553335e-002
|
||||
1.0353295e-001
|
||||
1.1168269e-001
|
||||
1.2000780e-001
|
||||
1.2850029e-001
|
||||
1.3715518e-001
|
||||
1.4597665e-001
|
||||
1.5496071e-001
|
||||
1.6409589e-001
|
||||
1.7338082e-001
|
||||
1.8281725e-001
|
||||
1.9239667e-001
|
||||
2.0212502e-001
|
||||
2.1197359e-001
|
||||
2.2196527e-001
|
||||
2.3206909e-001
|
||||
2.4230169e-001
|
||||
2.5264803e-001
|
||||
2.6310533e-001
|
||||
2.7366340e-001
|
||||
2.8432142e-001
|
||||
2.9507167e-001
|
||||
3.0590986e-001
|
||||
3.1682789e-001
|
||||
3.2781137e-001
|
||||
3.3887227e-001
|
||||
3.4999141e-001
|
||||
3.6115899e-001
|
||||
3.7237955e-001
|
||||
3.8363500e-001
|
||||
3.9492118e-001
|
||||
4.0623177e-001
|
||||
4.1756969e-001
|
||||
4.2891199e-001
|
||||
4.4025538e-001
|
||||
4.5159965e-001
|
||||
4.6293081e-001
|
||||
4.7424532e-001
|
||||
4.8552531e-001
|
||||
4.9677083e-001
|
||||
5.0798175e-001
|
||||
5.1912350e-001
|
||||
5.3022409e-001
|
||||
5.4125534e-001
|
||||
5.5220513e-001
|
||||
5.6307891e-001
|
||||
5.7385241e-001
|
||||
5.8454032e-001
|
||||
5.9511231e-001
|
||||
6.0557835e-001
|
||||
6.1591099e-001
|
||||
6.2612427e-001
|
||||
6.3619801e-001
|
||||
6.4612697e-001
|
||||
6.5590163e-001
|
||||
6.6551399e-001
|
||||
6.7496632e-001
|
||||
6.8423533e-001
|
||||
6.9332824e-001
|
||||
7.0223887e-001
|
||||
7.1094104e-001
|
||||
7.1944626e-001
|
||||
7.2774489e-001
|
||||
7.3582118e-001
|
||||
7.4368279e-001
|
||||
7.5131375e-001
|
||||
7.5870808e-001
|
||||
7.6586749e-001
|
||||
7.7277809e-001
|
||||
7.7942875e-001
|
||||
7.8583531e-001
|
||||
7.9197358e-001
|
||||
7.9784664e-001
|
||||
8.0344858e-001
|
||||
8.0876950e-001
|
||||
8.1381913e-001
|
||||
8.1857760e-001
|
||||
8.2304199e-001
|
||||
8.2722753e-001
|
||||
8.3110385e-001
|
||||
8.3469374e-001
|
||||
8.3797173e-001
|
||||
8.4095414e-001
|
||||
8.4362383e-001
|
||||
8.4598185e-001
|
||||
8.4803158e-001
|
||||
8.4978052e-001
|
||||
8.5119715e-001
|
||||
8.5230470e-001
|
||||
8.5310209e-001
|
||||
8.5357206e-001
|
||||
8.5373856e-001
|
||||
8.5357206e-001
|
||||
8.5310209e-001
|
||||
8.5230470e-001
|
||||
8.5119715e-001
|
||||
8.4978052e-001
|
||||
8.4803158e-001
|
||||
8.4598185e-001
|
||||
8.4362383e-001
|
||||
8.4095414e-001
|
||||
8.3797173e-001
|
||||
8.3469374e-001
|
||||
8.3110385e-001
|
||||
8.2722753e-001
|
||||
8.2304199e-001
|
||||
8.1857760e-001
|
||||
8.1381913e-001
|
||||
8.0876950e-001
|
||||
8.0344858e-001
|
||||
7.9784664e-001
|
||||
7.9197358e-001
|
||||
7.8583531e-001
|
||||
7.7942875e-001
|
||||
7.7277809e-001
|
||||
7.6586749e-001
|
||||
7.5870808e-001
|
||||
7.5131375e-001
|
||||
7.4368279e-001
|
||||
7.3582118e-001
|
||||
7.2774489e-001
|
||||
7.1944626e-001
|
||||
7.1094104e-001
|
||||
7.0223887e-001
|
||||
6.9332824e-001
|
||||
6.8423533e-001
|
||||
6.7496632e-001
|
||||
6.6551399e-001
|
||||
6.5590163e-001
|
||||
6.4612697e-001
|
||||
6.3619801e-001
|
||||
6.2612427e-001
|
||||
6.1591099e-001
|
||||
6.0557835e-001
|
||||
5.9511231e-001
|
||||
5.8454032e-001
|
||||
5.7385241e-001
|
||||
5.6307891e-001
|
||||
5.5220513e-001
|
||||
5.4125534e-001
|
||||
5.3022409e-001
|
||||
5.1912350e-001
|
||||
5.0798175e-001
|
||||
4.9677083e-001
|
||||
4.8552531e-001
|
||||
4.7424532e-001
|
||||
4.6293081e-001
|
||||
4.5159965e-001
|
||||
4.4025538e-001
|
||||
4.2891199e-001
|
||||
4.1756969e-001
|
||||
4.0623177e-001
|
||||
3.9492118e-001
|
||||
3.8363500e-001
|
||||
3.7237955e-001
|
||||
3.6115899e-001
|
||||
3.4999141e-001
|
||||
3.3887227e-001
|
||||
3.2781137e-001
|
||||
3.1682789e-001
|
||||
3.0590986e-001
|
||||
2.9507167e-001
|
||||
2.8432142e-001
|
||||
2.7366340e-001
|
||||
2.6310533e-001
|
||||
2.5264803e-001
|
||||
2.4230169e-001
|
||||
2.3206909e-001
|
||||
2.2196527e-001
|
||||
2.1197359e-001
|
||||
2.0212502e-001
|
||||
1.9239667e-001
|
||||
1.8281725e-001
|
||||
1.7338082e-001
|
||||
1.6409589e-001
|
||||
1.5496071e-001
|
||||
1.4597665e-001
|
||||
1.3715518e-001
|
||||
1.2850029e-001
|
||||
1.2000780e-001
|
||||
1.1168269e-001
|
||||
1.0353295e-001
|
||||
9.5553335e-002
|
||||
8.7754754e-002
|
||||
8.0137293e-002
|
||||
7.2694330e-002
|
||||
6.5440985e-002
|
||||
5.8370533e-002
|
||||
5.1480418e-002
|
||||
4.4780682e-002
|
||||
3.8277657e-002
|
||||
3.1953127e-002
|
||||
2.5822729e-002
|
||||
1.9883413e-002
|
||||
1.4128883e-002
|
||||
8.5711749e-003
|
||||
3.2086897e-003
|
||||
-1.9765601e-003
|
||||
-6.9636862e-003
|
||||
-1.1762383e-002
|
||||
-1.6370126e-002
|
||||
-2.0799707e-002
|
||||
-2.5030756e-002
|
||||
-2.9082401e-002
|
||||
-3.2958393e-002
|
||||
-3.6641812e-002
|
||||
-4.0145828e-002
|
||||
-4.3476878e-002
|
||||
-4.6630331e-002
|
||||
-4.9597868e-002
|
||||
-5.2409382e-002
|
||||
-5.5046003e-002
|
||||
-5.7515269e-002
|
||||
-5.9816657e-002
|
||||
-6.1960278e-002
|
||||
-6.3944481e-002
|
||||
-6.5769067e-002
|
||||
-6.7452502e-002
|
||||
-6.8966401e-002
|
||||
-7.0353307e-002
|
||||
-7.1582636e-002
|
||||
-7.2677464e-002
|
||||
-7.3640601e-002
|
||||
-7.4466439e-002
|
||||
-7.5157626e-002
|
||||
-7.5730576e-002
|
||||
-7.6174832e-002
|
||||
-7.6505072e-002
|
||||
-7.6720492e-002
|
||||
-7.6823001e-002
|
||||
-7.6817398e-002
|
||||
-7.6709349e-002
|
||||
-7.6499217e-002
|
||||
-7.6199248e-002
|
||||
-7.5800836e-002
|
||||
-7.5313734e-002
|
||||
-7.4745256e-002
|
||||
-7.4100364e-002
|
||||
-7.3362026e-002
|
||||
-7.2568258e-002
|
||||
-7.1700267e-002
|
||||
-7.0762871e-002
|
||||
-6.9763024e-002
|
||||
-6.8704383e-002
|
||||
-6.7607599e-002
|
||||
-6.6436751e-002
|
||||
-6.5224711e-002
|
||||
-6.3971590e-002
|
||||
-6.2685781e-002
|
||||
-6.1345517e-002
|
||||
-5.9983748e-002
|
||||
-5.8591568e-002
|
||||
-5.7161645e-002
|
||||
-5.5717365e-002
|
||||
-5.4245277e-002
|
||||
-5.2763075e-002
|
||||
-5.1255616e-002
|
||||
-4.9738576e-002
|
||||
-4.8216572e-002
|
||||
-4.6684303e-002
|
||||
-4.5148841e-002
|
||||
-4.3609754e-002
|
||||
-4.2064909e-002
|
||||
-4.0534917e-002
|
||||
-3.9005368e-002
|
||||
-3.7481285e-002
|
||||
-3.5969756e-002
|
||||
-3.4462095e-002
|
||||
-3.2975408e-002
|
||||
-3.1501761e-002
|
||||
-3.0050266e-002
|
||||
-2.8607217e-002
|
||||
-2.7185943e-002
|
||||
-2.5787585e-002
|
||||
-2.4416099e-002
|
||||
-2.3068017e-002
|
||||
-2.1746755e-002
|
||||
-2.0453179e-002
|
||||
-1.9187243e-002
|
||||
-1.7943338e-002
|
||||
-1.6732471e-002
|
||||
-1.5540555e-002
|
||||
-1.4390467e-002
|
||||
-1.3271822e-002
|
||||
-1.2185000e-002
|
||||
-1.1131555e-002
|
||||
-1.0115022e-002
|
||||
-9.1325330e-003
|
||||
-8.1798233e-003
|
||||
-7.2615817e-003
|
||||
-6.3792293e-003
|
||||
-5.5337211e-003
|
||||
-4.7222596e-003
|
||||
-3.9401124e-003
|
||||
-3.1933778e-003
|
||||
-2.4826724e-003
|
||||
-1.8039473e-003
|
||||
-1.1568136e-003
|
||||
-5.4642809e-004
|
||||
2.7604519e-005
|
||||
5.8322642e-004
|
||||
1.0902329e-003
|
||||
1.5784683e-003
|
||||
2.0274176e-003
|
||||
2.4508540e-003
|
||||
2.8446758e-003
|
||||
3.2091886e-003
|
||||
3.5401247e-003
|
||||
3.8456408e-003
|
||||
4.1251642e-003
|
||||
4.3801862e-003
|
||||
4.6039530e-003
|
||||
4.8109469e-003
|
||||
4.9839688e-003
|
||||
5.1382275e-003
|
||||
5.2715759e-003
|
||||
5.3838976e-003
|
||||
5.4753783e-003
|
||||
5.5404364e-003
|
||||
5.5917129e-003
|
||||
5.6266114e-003
|
||||
5.6389200e-003
|
||||
5.6455197e-003
|
||||
5.6220643e-003
|
||||
5.5938023e-003
|
||||
5.5475715e-003
|
||||
5.4876040e-003
|
||||
5.4196776e-003
|
||||
5.3471681e-003
|
||||
5.2461166e-003
|
||||
5.1407354e-003
|
||||
5.0393023e-003
|
||||
4.9137604e-003
|
||||
4.7932561e-003
|
||||
4.6606461e-003
|
||||
4.5209853e-003
|
||||
4.3730720e-003
|
||||
4.2264269e-003
|
||||
4.0819753e-003
|
||||
3.9207432e-003
|
||||
3.7603923e-003
|
||||
3.6008268e-003
|
||||
3.4418874e-003
|
||||
3.2739613e-003
|
||||
3.1125421e-003
|
||||
2.9469448e-003
|
||||
2.7870464e-003
|
||||
2.6201759e-003
|
||||
2.4625617e-003
|
||||
2.3017255e-003
|
||||
2.1461584e-003
|
||||
1.9841141e-003
|
||||
1.8348265e-003
|
||||
1.6868083e-003
|
||||
1.5443220e-003
|
||||
1.3902495e-003
|
||||
1.2577885e-003
|
||||
1.1250155e-003
|
||||
9.8859883e-004
|
||||
8.6084433e-004
|
||||
7.4580259e-004
|
||||
6.2393761e-004
|
||||
5.1073885e-004
|
||||
4.0265402e-004
|
||||
2.9495311e-004
|
||||
2.0430171e-004
|
||||
1.0943831e-004
|
||||
1.3494974e-005
|
||||
-6.1733441e-005
|
||||
-1.4463809e-004
|
||||
-2.0983373e-004
|
||||
-2.8969812e-004
|
||||
-3.5011759e-004
|
||||
-4.0951215e-004
|
||||
-4.6063255e-004
|
||||
-5.1455722e-004
|
||||
-5.5645764e-004
|
||||
-5.9461189e-004
|
||||
-6.3415949e-004
|
||||
-6.6504151e-004
|
||||
-6.9179375e-004
|
||||
-7.2153920e-004
|
||||
-7.3193572e-004
|
||||
-7.5300014e-004
|
||||
-7.6307936e-004
|
||||
-7.7579773e-004
|
||||
-7.8014496e-004
|
||||
-7.8036647e-004
|
||||
-7.7798695e-004
|
||||
-7.8343323e-004
|
||||
-7.7248486e-004
|
||||
-7.6813719e-004
|
||||
-7.4905981e-004
|
||||
-7.4409419e-004
|
||||
-7.2550431e-004
|
||||
-7.1577365e-004
|
||||
-6.9416146e-004
|
||||
-6.7776908e-004
|
||||
-6.5403334e-004
|
||||
-6.3124935e-004
|
||||
-6.1327474e-004
|
||||
-5.8709305e-004
|
||||
-5.6778026e-004
|
||||
-5.4665656e-004
|
||||
-5.2265643e-004
|
||||
-5.0407143e-004
|
||||
-4.8937912e-004
|
||||
-4.8752280e-004
|
||||
-4.9475181e-004
|
||||
-5.6176926e-004
|
||||
-5.5252865e-004
|
||||
@@ -0,0 +1,102 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Stretch2d(torch.nn.Module):
|
||||
def __init__(self, x_scale, y_scale, mode="nearest"):
|
||||
super().__init__()
|
||||
self.x_scale = x_scale
|
||||
self.y_scale = y_scale
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x (Tensor): Input tensor (B, C, F, T).
|
||||
Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
|
||||
"""
|
||||
return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)
|
||||
|
||||
|
||||
class UpsampleNetwork(torch.nn.Module):
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
upsample_factors,
|
||||
nonlinear_activation=None,
|
||||
nonlinear_activation_params={},
|
||||
interpolate_mode="nearest",
|
||||
freq_axis_kernel_size=1,
|
||||
use_causal_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_causal_conv = use_causal_conv
|
||||
self.up_layers = torch.nn.ModuleList()
|
||||
for scale in upsample_factors:
|
||||
# interpolation layer
|
||||
stretch = Stretch2d(scale, 1, interpolate_mode)
|
||||
self.up_layers += [stretch]
|
||||
|
||||
# conv layer
|
||||
assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
|
||||
freq_axis_padding = (freq_axis_kernel_size - 1) // 2
|
||||
kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
|
||||
if use_causal_conv:
|
||||
padding = (freq_axis_padding, scale * 2)
|
||||
else:
|
||||
padding = (freq_axis_padding, scale)
|
||||
conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
|
||||
self.up_layers += [conv]
|
||||
|
||||
# nonlinear
|
||||
if nonlinear_activation is not None:
|
||||
nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
|
||||
self.up_layers += [nonlinear]
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
c : (B, C, T_in).
|
||||
Tensor: (B, C, T_upsample)
|
||||
"""
|
||||
c = c.unsqueeze(1) # (B, 1, C, T)
|
||||
for f in self.up_layers:
|
||||
c = f(c)
|
||||
return c.squeeze(1) # (B, C, T')
|
||||
|
||||
|
||||
class ConvUpsample(torch.nn.Module):
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
upsample_factors,
|
||||
nonlinear_activation=None,
|
||||
nonlinear_activation_params={},
|
||||
interpolate_mode="nearest",
|
||||
freq_axis_kernel_size=1,
|
||||
aux_channels=80,
|
||||
aux_context_window=0,
|
||||
use_causal_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.aux_context_window = aux_context_window
|
||||
self.use_causal_conv = use_causal_conv and aux_context_window > 0
|
||||
# To capture wide-context information in conditional features
|
||||
kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
|
||||
# NOTE(kan-bayashi): Here do not use padding because the input is already padded
|
||||
self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
|
||||
self.upsample = UpsampleNetwork(
|
||||
upsample_factors=upsample_factors,
|
||||
nonlinear_activation=nonlinear_activation,
|
||||
nonlinear_activation_params=nonlinear_activation_params,
|
||||
interpolate_mode=interpolate_mode,
|
||||
freq_axis_kernel_size=freq_axis_kernel_size,
|
||||
use_causal_conv=use_causal_conv,
|
||||
)
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
c : (B, C, T_in).
|
||||
Tensor: (B, C, T_upsampled),
|
||||
"""
|
||||
c_ = self.conv_in(c)
|
||||
c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
|
||||
return self.upsample(c)
|
||||
@@ -0,0 +1,166 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
nn.init.orthogonal_(self.weight)
|
||||
nn.init.zeros_(self.bias)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"""Positional encoding with noise level conditioning"""
|
||||
|
||||
def __init__(self, n_channels, max_len=10000):
|
||||
super().__init__()
|
||||
self.n_channels = n_channels
|
||||
self.max_len = max_len
|
||||
self.C = 5000
|
||||
self.pe = torch.zeros(0, 0)
|
||||
|
||||
def forward(self, x, noise_level):
|
||||
if x.shape[2] > self.pe.shape[1]:
|
||||
self.init_pe_matrix(x.shape[1], x.shape[2], x)
|
||||
return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C
|
||||
|
||||
def init_pe_matrix(self, n_channels, max_len, x):
|
||||
pe = torch.zeros(max_len, n_channels)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
|
||||
|
||||
pe[:, 0::2] = torch.sin(position / div_term)
|
||||
pe[:, 1::2] = torch.cos(position / div_term)
|
||||
self.pe = pe.transpose(0, 1).to(x)
|
||||
|
||||
|
||||
class FiLM(nn.Module):
|
||||
def __init__(self, input_size, output_size):
|
||||
super().__init__()
|
||||
self.encoding = PositionalEncoding(input_size)
|
||||
self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
|
||||
self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
|
||||
|
||||
nn.init.xavier_uniform_(self.input_conv.weight)
|
||||
nn.init.xavier_uniform_(self.output_conv.weight)
|
||||
nn.init.zeros_(self.input_conv.bias)
|
||||
nn.init.zeros_(self.output_conv.bias)
|
||||
|
||||
def forward(self, x, noise_scale):
|
||||
o = self.input_conv(x)
|
||||
o = F.leaky_relu(o, 0.2)
|
||||
o = self.encoding(o, noise_scale)
|
||||
shift, scale = torch.chunk(self.output_conv(o), 2, dim=1)
|
||||
return shift, scale
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_parametrizations(self.input_conv, "weight")
|
||||
remove_parametrizations(self.output_conv, "weight")
|
||||
|
||||
def apply_weight_norm(self):
|
||||
self.input_conv = weight_norm(self.input_conv)
|
||||
self.output_conv = weight_norm(self.output_conv)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def shif_and_scale(x, scale, shift):
|
||||
o = shift + scale * x
|
||||
return o
|
||||
|
||||
|
||||
class UBlock(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, factor, dilation):
|
||||
super().__init__()
|
||||
assert isinstance(dilation, (list, tuple))
|
||||
assert len(dilation) == 4
|
||||
|
||||
self.factor = factor
|
||||
self.res_block = Conv1d(input_size, hidden_size, 1)
|
||||
self.main_block = nn.ModuleList(
|
||||
[
|
||||
Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
|
||||
Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]),
|
||||
]
|
||||
)
|
||||
self.out_block = nn.ModuleList(
|
||||
[
|
||||
Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
|
||||
Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, shift, scale):
|
||||
x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
|
||||
res = self.res_block(x_inter)
|
||||
o = F.leaky_relu(x_inter, 0.2)
|
||||
o = F.interpolate(o, size=x.shape[-1] * self.factor)
|
||||
o = self.main_block[0](o)
|
||||
o = shif_and_scale(o, scale, shift)
|
||||
o = F.leaky_relu(o, 0.2)
|
||||
o = self.main_block[1](o)
|
||||
res2 = res + o
|
||||
o = shif_and_scale(res2, scale, shift)
|
||||
o = F.leaky_relu(o, 0.2)
|
||||
o = self.out_block[0](o)
|
||||
o = shif_and_scale(o, scale, shift)
|
||||
o = F.leaky_relu(o, 0.2)
|
||||
o = self.out_block[1](o)
|
||||
o = o + res2
|
||||
return o
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_parametrizations(self.res_block, "weight")
|
||||
for _, layer in enumerate(self.main_block):
|
||||
if len(layer.state_dict()) != 0:
|
||||
remove_parametrizations(layer, "weight")
|
||||
for _, layer in enumerate(self.out_block):
|
||||
if len(layer.state_dict()) != 0:
|
||||
remove_parametrizations(layer, "weight")
|
||||
|
||||
def apply_weight_norm(self):
|
||||
self.res_block = weight_norm(self.res_block)
|
||||
for idx, layer in enumerate(self.main_block):
|
||||
if len(layer.state_dict()) != 0:
|
||||
self.main_block[idx] = weight_norm(layer)
|
||||
for idx, layer in enumerate(self.out_block):
|
||||
if len(layer.state_dict()) != 0:
|
||||
self.out_block[idx] = weight_norm(layer)
|
||||
|
||||
|
||||
class DBlock(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, factor):
|
||||
super().__init__()
|
||||
self.factor = factor
|
||||
self.res_block = Conv1d(input_size, hidden_size, 1)
|
||||
self.main_block = nn.ModuleList(
|
||||
[
|
||||
Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
|
||||
Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
|
||||
Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
size = x.shape[-1] // self.factor
|
||||
res = self.res_block(x)
|
||||
res = F.interpolate(res, size=size)
|
||||
o = F.interpolate(x, size=size)
|
||||
for layer in self.main_block:
|
||||
o = F.leaky_relu(o, 0.2)
|
||||
o = layer(o)
|
||||
return o + res
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_parametrizations(self.res_block, "weight")
|
||||
for _, layer in enumerate(self.main_block):
|
||||
if len(layer.state_dict()) != 0:
|
||||
remove_parametrizations(layer, "weight")
|
||||
|
||||
def apply_weight_norm(self):
|
||||
self.res_block = weight_norm(self.res_block)
|
||||
for idx, layer in enumerate(self.main_block):
|
||||
if len(layer.state_dict()) != 0:
|
||||
self.main_block[idx] = weight_norm(layer)
|
||||
@@ -0,0 +1,154 @@
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from coqpit import Coqpit
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_model(config: Coqpit):
|
||||
"""Load models directly from configuration."""
|
||||
if "discriminator_model" in config and "generator_model" in config:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.gan")
|
||||
MyModel = getattr(MyModel, "GAN")
|
||||
else:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
|
||||
if config.model.lower() == "wavernn":
|
||||
MyModel = getattr(MyModel, "Wavernn")
|
||||
elif config.model.lower() == "gan":
|
||||
MyModel = getattr(MyModel, "GAN")
|
||||
elif config.model.lower() == "wavegrad":
|
||||
MyModel = getattr(MyModel, "Wavegrad")
|
||||
else:
|
||||
try:
|
||||
MyModel = getattr(MyModel, to_camel(config.model))
|
||||
except ModuleNotFoundError as e:
|
||||
raise ValueError(f"Model {config.model} not exist!") from e
|
||||
print(" > Vocoder Model: {}".format(config.model))
|
||||
return MyModel.init_from_config(config)
|
||||
|
||||
|
||||
def setup_generator(c):
|
||||
"""TODO: use config object as arguments"""
|
||||
print(" > Generator Model: {}".format(c.generator_model))
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||
# this is to preserve the Wavernn class name (instead of Wavernn)
|
||||
if c.generator_model.lower() in "hifigan_generator":
|
||||
model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
|
||||
elif c.generator_model.lower() in "melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model in "melgan_fb_generator":
|
||||
raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
|
||||
elif c.generator_model.lower() in "multiband_melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=4,
|
||||
proj_kernel=7,
|
||||
base_channels=384,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model.lower() in "fullband_melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model.lower() in "parallel_wavegan_generator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
stacks=c.generator_model_params["stacks"],
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
aux_channels=c.audio["num_mels"],
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
use_weight_norm=True,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
)
|
||||
elif c.generator_model.lower() in "univnet_generator":
|
||||
model = MyModel(**c.generator_model_params)
|
||||
else:
|
||||
raise NotImplementedError(f"Model {c.generator_model} not implemented!")
|
||||
return model
|
||||
|
||||
|
||||
def setup_discriminator(c):
|
||||
"""TODO: use config objekt as arguments"""
|
||||
print(" > Discriminator Model: {}".format(c.discriminator_model))
|
||||
if "parallel_wavegan" in c.discriminator_model:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
|
||||
else:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
|
||||
if c.discriminator_model in "hifigan_discriminator":
|
||||
model = MyModel()
|
||||
if c.discriminator_model in "random_window_discriminator":
|
||||
model = MyModel(
|
||||
cond_channels=c.audio["num_mels"],
|
||||
hop_length=c.audio["hop_length"],
|
||||
uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
|
||||
cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
|
||||
cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
|
||||
window_sizes=c.discriminator_model_params["window_sizes"],
|
||||
)
|
||||
if c.discriminator_model in "melgan_multiscale_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=c.discriminator_model_params["base_channels"],
|
||||
max_channels=c.discriminator_model_params["max_channels"],
|
||||
downsample_factors=c.discriminator_model_params["downsample_factors"],
|
||||
)
|
||||
if c.discriminator_model == "residual_parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
stacks=c.discriminator_model_params["stacks"],
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
)
|
||||
if c.discriminator_model == "parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
conv_channels=64,
|
||||
dilation_factor=1,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
bias=True,
|
||||
)
|
||||
if c.discriminator_model == "univnet_discriminator":
|
||||
model = MyModel()
|
||||
return model
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,55 @@
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.model import BaseTrainerModel
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseVocoder(BaseTrainerModel):
|
||||
"""Base `vocoder` class. Every new `vocoder` model must inherit this.
|
||||
|
||||
It defines `vocoder` specific functions on top of `Model`.
|
||||
|
||||
Notes on input/output tensor shapes:
|
||||
Any input or output tensor of the model must be shaped as
|
||||
|
||||
- 3D tensors `batch x time x channels`
|
||||
- 2D tensors `batch x channels`
|
||||
- 1D tensors `batch x 1`
|
||||
"""
|
||||
|
||||
MODEL_TYPE = "vocoder"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self._set_model_args(config)
|
||||
|
||||
def _set_model_args(self, config: Coqpit):
|
||||
"""Setup model args based on the config type.
|
||||
|
||||
If the config is for training with a name like "*Config", then the model args are embeded in the
|
||||
config.model_args
|
||||
|
||||
If the config is for the model with a name like "*Args", then we assign the directly.
|
||||
"""
|
||||
# don't use isintance not to import recursively
|
||||
if "Config" in config.__class__.__name__:
|
||||
if "characters" in config:
|
||||
_, self.config, num_chars = self.get_characters(config)
|
||||
self.config.num_chars = num_chars
|
||||
if hasattr(self.config, "model_args"):
|
||||
config.model_args.num_chars = num_chars
|
||||
if "model_args" in config:
|
||||
self.args = self.config.model_args
|
||||
# This is for backward compatibility
|
||||
if "model_params" in config:
|
||||
self.args = self.config.model_params
|
||||
else:
|
||||
self.config = config
|
||||
if "model_args" in config:
|
||||
self.args = self.config.model_args
|
||||
# This is for backward compatibility
|
||||
if "model_params" in config:
|
||||
self.args = self.config.model_params
|
||||
else:
|
||||
raise ValueError("config must be either a *Config or *Args")
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
||||
|
||||
class FullbandMelganGenerator(MelganGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=(2, 8, 2, 2),
|
||||
res_kernel=3,
|
||||
num_res_blocks=4,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
proj_kernel=proj_kernel,
|
||||
base_channels=base_channels,
|
||||
upsample_factors=upsample_factors,
|
||||
res_kernel=res_kernel,
|
||||
num_res_blocks=num_res_blocks,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, cond_features):
|
||||
cond_features = cond_features.to(self.layers[1].weight.device)
|
||||
cond_features = torch.nn.functional.pad(
|
||||
cond_features, (self.inference_padding, self.inference_padding), "replicate"
|
||||
)
|
||||
return self.layers(cond_features)
|
||||
@@ -0,0 +1,374 @@
|
||||
from inspect import signature
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||
from TTS.vocoder.models import setup_discriminator, setup_generator
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
class GAN(BaseVocoder):
|
||||
def __init__(self, config: Coqpit, ap: AudioProcessor = None):
|
||||
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
|
||||
It also helps mixing and matching different generator and disciminator networks easily.
|
||||
|
||||
To implement a new GAN models, you just need to define the generator and the discriminator networks, the rest
|
||||
is handled by the `GAN` class.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.
|
||||
|
||||
Examples:
|
||||
Initializing the GAN model with HifiGAN generator and discriminator.
|
||||
>>> from TTS.vocoder.configs import HifiganConfig
|
||||
>>> config = HifiganConfig()
|
||||
>>> model = GAN(config)
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.model_g = setup_generator(config)
|
||||
self.model_d = setup_discriminator(config)
|
||||
self.train_disc = False # if False, train only the generator.
|
||||
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
|
||||
self.ap = ap
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the generator's forward pass.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: output of the GAN generator network.
|
||||
"""
|
||||
return self.model_g.forward(x)
|
||||
|
||||
def inference(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the generator's inference pass.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
Returns:
|
||||
torch.Tensor: output of the GAN generator network.
|
||||
"""
|
||||
return self.model_g.inference(x)
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
|
||||
network on the current pass.
|
||||
|
||||
Args:
|
||||
batch (Dict): Batch of samples returned by the dataloader.
|
||||
criterion (Dict): Criterion used to compute the losses.
|
||||
optimizer_idx (int): ID of the optimizer in use on the current pass.
|
||||
|
||||
Raises:
|
||||
ValueError: `optimizer_idx` is an unexpected value.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: model outputs and the computed loss values.
|
||||
"""
|
||||
outputs = {}
|
||||
loss_dict = {}
|
||||
|
||||
x = batch["input"]
|
||||
y = batch["waveform"]
|
||||
|
||||
if optimizer_idx not in [0, 1]:
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# DISCRIMINATOR optimization
|
||||
|
||||
# generator pass
|
||||
y_hat = self.model_g(x)[:, :, : y.size(2)]
|
||||
|
||||
# cache for generator loss
|
||||
# pylint: disable=W0201
|
||||
self.y_hat_g = y_hat
|
||||
self.y_hat_sub = None
|
||||
self.y_sub_g = None
|
||||
|
||||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
self.y_hat_sub = y_hat
|
||||
y_hat = self.model_g.pqmf_synthesis(y_hat)
|
||||
self.y_hat_g = y_hat # save for generator loss
|
||||
self.y_sub_g = self.model_g.pqmf_analysis(y)
|
||||
|
||||
scores_fake, feats_fake, feats_real = None, None, None
|
||||
|
||||
if self.train_disc:
|
||||
# use different samples for G and D trainings
|
||||
if self.config.diff_samples_for_G_and_D:
|
||||
x_d = batch["input_disc"]
|
||||
y_d = batch["waveform_disc"]
|
||||
# use a different sample than generator
|
||||
with torch.no_grad():
|
||||
y_hat = self.model_g(x_d)
|
||||
|
||||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
y_hat = self.model_g.pqmf_synthesis(y_hat)
|
||||
else:
|
||||
# use the same samples as generator
|
||||
x_d = x.clone()
|
||||
y_d = y.clone()
|
||||
y_hat = self.y_hat_g
|
||||
|
||||
# run D with or without cond. features
|
||||
if len(signature(self.model_d.forward).parameters) == 2:
|
||||
D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
|
||||
D_out_real = self.model_d(y_d, x_d)
|
||||
else:
|
||||
D_out_fake = self.model_d(y_hat.detach())
|
||||
D_out_real = self.model_d(y_d)
|
||||
|
||||
# format D outputs
|
||||
if isinstance(D_out_fake, tuple):
|
||||
# self.model_d returns scores and features
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
scores_real, feats_real = None, None
|
||||
else:
|
||||
scores_real, feats_real = D_out_real
|
||||
else:
|
||||
# model D returns only scores
|
||||
scores_fake = D_out_fake
|
||||
scores_real = D_out_real
|
||||
|
||||
# compute losses
|
||||
loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
|
||||
outputs = {"model_outputs": y_hat}
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# GENERATOR loss
|
||||
scores_fake, feats_fake, feats_real = None, None, None
|
||||
if self.train_disc:
|
||||
if len(signature(self.model_d.forward).parameters) == 2:
|
||||
D_out_fake = self.model_d(self.y_hat_g, x)
|
||||
else:
|
||||
D_out_fake = self.model_d(self.y_hat_g)
|
||||
D_out_real = None
|
||||
|
||||
if self.config.use_feat_match_loss:
|
||||
with torch.no_grad():
|
||||
D_out_real = self.model_d(y)
|
||||
|
||||
# format D outputs
|
||||
if isinstance(D_out_fake, tuple):
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
feats_real = None
|
||||
else:
|
||||
_, feats_real = D_out_real
|
||||
else:
|
||||
scores_fake = D_out_fake
|
||||
feats_fake, feats_real = None, None
|
||||
|
||||
# compute losses
|
||||
loss_dict = criterion[optimizer_idx](
|
||||
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
|
||||
)
|
||||
outputs = {"model_outputs": self.y_hat_g}
|
||||
return outputs, loss_dict
|
||||
|
||||
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Logging shared by the training and evaluation.
|
||||
|
||||
Args:
|
||||
name (str): Name of the run. `train` or `eval`,
|
||||
ap (AudioProcessor): Audio processor used in training.
|
||||
batch (Dict): Batch used in the last train/eval step.
|
||||
outputs (Dict): Model outputs from the last train/eval step.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: log figures and audio samples.
|
||||
"""
|
||||
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
|
||||
y = batch["waveform"]
|
||||
figures = plot_results(y_hat, y, ap, name)
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
audios = {f"{name}/audio": sample_voice}
|
||||
return figures, audios
|
||||
|
||||
def train_log(
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
"""Call `_log()` for training."""
|
||||
figures, audios = self._log("eval", self.ap, batch, outputs)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Call `train_step()` with `no_grad()`"""
|
||||
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
|
||||
return self.train_step(batch, criterion, optimizer_idx)
|
||||
|
||||
def eval_log(
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
"""Call `_log()` for evaluation."""
|
||||
figures, audios = self._log("eval", self.ap, batch, outputs)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
config: Coqpit,
|
||||
checkpoint_path: str,
|
||||
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
|
||||
cache: bool = False,
|
||||
) -> None:
|
||||
"""Load a GAN checkpoint and initialize model parameters.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model config.
|
||||
checkpoint_path (str): Checkpoint file path.
|
||||
eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
|
||||
"""
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
# band-aid for older than v0.0.15 GAN models
|
||||
if "model_disc" in state:
|
||||
self.model_g.load_checkpoint(config, checkpoint_path, eval)
|
||||
else:
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.model_d = None
|
||||
if hasattr(self.model_g, "remove_weight_norm"):
|
||||
self.model_g.remove_weight_norm()
|
||||
|
||||
def on_train_step_start(self, trainer) -> None:
|
||||
"""Enable the discriminator training based on `steps_to_start_discriminator`
|
||||
|
||||
Args:
|
||||
trainer (Trainer): Trainer object.
|
||||
"""
|
||||
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
"""Initiate and return the GAN optimizers based on the config parameters.
|
||||
|
||||
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
|
||||
|
||||
Returns:
|
||||
List: optimizers.
|
||||
"""
|
||||
optimizer1 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
|
||||
)
|
||||
optimizer2 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
|
||||
)
|
||||
return [optimizer2, optimizer1]
|
||||
|
||||
def get_lr(self) -> List:
|
||||
"""Set the initial learning rates for each optimizer.
|
||||
|
||||
Returns:
|
||||
List: learning rates for each optimizer.
|
||||
"""
|
||||
return [self.config.lr_disc, self.config.lr_gen]
|
||||
|
||||
def get_scheduler(self, optimizer) -> List:
|
||||
"""Set the schedulers for each optimizer.
|
||||
|
||||
Args:
|
||||
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
|
||||
|
||||
Returns:
|
||||
List: Schedulers, one for each optimizer.
|
||||
"""
|
||||
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
|
||||
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||
return [scheduler2, scheduler1]
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: List) -> Dict:
|
||||
"""Format the batch for training.
|
||||
|
||||
Args:
|
||||
batch (List): Batch out of the dataloader.
|
||||
|
||||
Returns:
|
||||
Dict: formatted model inputs.
|
||||
"""
|
||||
if isinstance(batch[0], list):
|
||||
x_G, y_G = batch[0]
|
||||
x_D, y_D = batch[1]
|
||||
return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
|
||||
x, y = batch
|
||||
return {"input": x, "waveform": y}
|
||||
|
||||
def get_data_loader( # pylint: disable=no-self-use, unused-argument
|
||||
self,
|
||||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: True,
|
||||
samples: List,
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
rank: int = None, # pylint: disable=unused-argument
|
||||
):
|
||||
"""Initiate and return the GAN dataloader.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model config.
|
||||
ap (AudioProcessor): Audio processor.
|
||||
is_eval (True): Set the dataloader for evaluation if true.
|
||||
samples (List): Data samples.
|
||||
verbose (bool): Log information if true.
|
||||
num_gpus (int): Number of GPUs in use.
|
||||
rank (int): Rank of the current GPU. Defaults to None.
|
||||
|
||||
Returns:
|
||||
DataLoader: Torch dataloader.
|
||||
"""
|
||||
dataset = GANDataset(
|
||||
ap=self.ap,
|
||||
items=samples,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=self.ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
|
||||
is_training=not is_eval,
|
||||
return_segments=not is_eval,
|
||||
use_noise_augment=config.use_noise_augment,
|
||||
use_cache=config.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
dataset.shuffle_mapping()
|
||||
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1 if is_eval else config.batch_size,
|
||||
shuffle=num_gpus == 0,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def get_criterion(self):
|
||||
"""Return criterions for the optimizers"""
|
||||
return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: Coqpit, verbose=True) -> "GAN":
|
||||
ap = AudioProcessor.init_from_config(config, verbose=verbose)
|
||||
return GAN(config, ap=ap)
|
||||
@@ -0,0 +1,217 @@
|
||||
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
"""HiFiGAN Periodic Discriminator
|
||||
|
||||
Takes every Pth value from the input waveform and applied a stack of convoluations.
|
||||
|
||||
Note:
|
||||
if `period` is 2
|
||||
`waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat`
|
||||
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
[Tensor]: discriminator scores per sample in the batch.
|
||||
[List[Tensor]]: list of features from each convolutional layer.
|
||||
|
||||
Shapes:
|
||||
x: [B, 1, T]
|
||||
"""
|
||||
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
get_padding = lambda k, d: int((k * d - d) / 2)
|
||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
[Tensor]: discriminator scores per sample in the batch.
|
||||
[List[Tensor]]: list of features from each convolutional layer.
|
||||
|
||||
Shapes:
|
||||
x: [B, 1, T]
|
||||
"""
|
||||
feat = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
feat.append(x)
|
||||
x = self.conv_post(x)
|
||||
feat.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, feat
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
"""HiFiGAN Multi-Period Discriminator (MPD)
|
||||
Wrapper for the `PeriodDiscriminator` to apply it in different periods.
|
||||
Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
|
||||
"""
|
||||
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
[List[Tensor]]: list of scores from each discriminator.
|
||||
[List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
|
||||
|
||||
Shapes:
|
||||
x: [B, 1, T]
|
||||
"""
|
||||
scores = []
|
||||
feats = []
|
||||
for _, d in enumerate(self.discriminators):
|
||||
score, feat = d(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
return scores, feats
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
"""HiFiGAN Scale Discriminator.
|
||||
It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper.
|
||||
|
||||
Args:
|
||||
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
|
||||
norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
||||
norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
Tensor: discriminator scores.
|
||||
List[Tensor]: list of features from the convolutiona layers.
|
||||
"""
|
||||
feat = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
feat.append(x)
|
||||
x = self.conv_post(x)
|
||||
feat.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
return x, feat
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
|
||||
"""HiFiGAN Multi-Scale Discriminator.
|
||||
It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorS(use_spectral_norm=True),
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
]
|
||||
)
|
||||
self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)])
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: discriminator scores.
|
||||
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||
"""
|
||||
scores = []
|
||||
feats = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
x = self.meanpools[i - 1](x)
|
||||
score, feat = d(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
return scores, feats
|
||||
|
||||
|
||||
class HifiganDiscriminator(nn.Module):
|
||||
"""HiFiGAN discriminator wrapping MPD and MSD."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
self.msd = MultiScaleDiscriminator()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: discriminator scores.
|
||||
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||
"""
|
||||
scores, feats = self.mpd(x)
|
||||
scores_, feats_ = self.msd(x)
|
||||
return scores + scores_, feats + feats_
|
||||
@@ -0,0 +1,301 @@
|
||||
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def get_padding(k, d):
|
||||
return int((k * d - d) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|
||||
|--------------------------------------------------------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input tensor.
|
||||
Returns:
|
||||
Tensor: output tensor.
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
"""
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.convs2:
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|
||||
|---------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
|
||||
class HifiganGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
resblock_type,
|
||||
resblock_dilation_sizes,
|
||||
resblock_kernel_sizes,
|
||||
upsample_kernel_sizes,
|
||||
upsample_initial_channel,
|
||||
upsample_factors,
|
||||
inference_padding=5,
|
||||
cond_channels=0,
|
||||
conv_pre_weight_norm=True,
|
||||
conv_post_weight_norm=True,
|
||||
conv_post_bias=True,
|
||||
):
|
||||
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
|
||||
|
||||
Network:
|
||||
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
|
||||
.. -> zI ---|
|
||||
resblockN_kNx1 -> zN ---'
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
out_channels (int): number of output tensor channels.
|
||||
resblock_type (str): type of the `ResBlock`. '1' or '2'.
|
||||
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
|
||||
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
|
||||
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
|
||||
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
|
||||
for each consecutive upsampling layer.
|
||||
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
|
||||
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_padding = inference_padding
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_factors)
|
||||
# initial upsampling layers
|
||||
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
|
||||
# upsampling layers
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
# MRF blocks
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
# post convolution layer
|
||||
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
|
||||
if cond_channels > 0:
|
||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||
|
||||
if not conv_pre_weight_norm:
|
||||
remove_parametrizations(self.conv_pre, "weight")
|
||||
|
||||
if not conv_post_weight_norm:
|
||||
remove_parametrizations(self.conv_post, "weight")
|
||||
|
||||
def forward(self, x, g=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): feature input tensor.
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
o = self.conv_pre(x)
|
||||
if hasattr(self, "cond_layer"):
|
||||
o = o + self.cond_layer(g)
|
||||
for i in range(self.num_upsamples):
|
||||
o = F.leaky_relu(o, LRELU_SLOPE)
|
||||
o = self.ups[i](o)
|
||||
z_sum = None
|
||||
for j in range(self.num_kernels):
|
||||
if z_sum is None:
|
||||
z_sum = self.resblocks[i * self.num_kernels + j](o)
|
||||
else:
|
||||
z_sum += self.resblocks[i * self.num_kernels + j](o)
|
||||
o = z_sum / self.num_kernels
|
||||
o = F.leaky_relu(o)
|
||||
o = self.conv_post(o)
|
||||
o = torch.tanh(o)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_parametrizations(self.conv_pre, "weight")
|
||||
remove_parametrizations(self.conv_post, "weight")
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
@@ -0,0 +1,84 @@
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
|
||||
class MelganDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=16,
|
||||
max_channels=1024,
|
||||
downsample_factors=(4, 4, 4, 4),
|
||||
groups_denominator=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
layer_kernel_size = np.prod(kernel_sizes)
|
||||
layer_padding = (layer_kernel_size - 1) // 2
|
||||
|
||||
# initial layer
|
||||
self.layers += [
|
||||
nn.Sequential(
|
||||
nn.ReflectionPad1d(layer_padding),
|
||||
weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
]
|
||||
|
||||
# downsampling layers
|
||||
layer_in_channels = base_channels
|
||||
for downsample_factor in downsample_factors:
|
||||
layer_out_channels = min(layer_in_channels * downsample_factor, max_channels)
|
||||
layer_kernel_size = downsample_factor * 10 + 1
|
||||
layer_padding = (layer_kernel_size - 1) // 2
|
||||
layer_groups = layer_in_channels // groups_denominator
|
||||
self.layers += [
|
||||
nn.Sequential(
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
layer_in_channels,
|
||||
layer_out_channels,
|
||||
kernel_size=layer_kernel_size,
|
||||
stride=downsample_factor,
|
||||
padding=layer_padding,
|
||||
groups=layer_groups,
|
||||
)
|
||||
),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
]
|
||||
layer_in_channels = layer_out_channels
|
||||
|
||||
# last 2 layers
|
||||
layer_padding1 = (kernel_sizes[0] - 1) // 2
|
||||
layer_padding2 = (kernel_sizes[1] - 1) // 2
|
||||
self.layers += [
|
||||
nn.Sequential(
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
layer_out_channels,
|
||||
layer_out_channels,
|
||||
kernel_size=kernel_sizes[0],
|
||||
stride=1,
|
||||
padding=layer_padding1,
|
||||
)
|
||||
),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
feats = []
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
feats.append(x)
|
||||
return x, feats
|
||||
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.layers.melgan import ResidualStack
|
||||
|
||||
|
||||
class MelganGenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=(8, 8, 2, 2),
|
||||
res_kernel=3,
|
||||
num_res_blocks=3,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# assert model parameters
|
||||
assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."
|
||||
|
||||
# setup additional model parameters
|
||||
base_padding = (proj_kernel - 1) // 2
|
||||
act_slope = 0.2
|
||||
self.inference_padding = 2
|
||||
|
||||
# initial layer
|
||||
layers = []
|
||||
layers += [
|
||||
nn.ReflectionPad1d(base_padding),
|
||||
weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
|
||||
]
|
||||
|
||||
# upsampling layers and residual stacks
|
||||
for idx, upsample_factor in enumerate(upsample_factors):
|
||||
layer_in_channels = base_channels // (2**idx)
|
||||
layer_out_channels = base_channels // (2 ** (idx + 1))
|
||||
layer_filter_size = upsample_factor * 2
|
||||
layer_stride = upsample_factor
|
||||
layer_output_padding = upsample_factor % 2
|
||||
layer_padding = upsample_factor // 2 + layer_output_padding
|
||||
layers += [
|
||||
nn.LeakyReLU(act_slope),
|
||||
weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
layer_in_channels,
|
||||
layer_out_channels,
|
||||
layer_filter_size,
|
||||
stride=layer_stride,
|
||||
padding=layer_padding,
|
||||
output_padding=layer_output_padding,
|
||||
bias=True,
|
||||
)
|
||||
),
|
||||
ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
|
||||
]
|
||||
|
||||
layers += [nn.LeakyReLU(act_slope)]
|
||||
|
||||
# final layer
|
||||
layers += [
|
||||
nn.ReflectionPad1d(base_padding),
|
||||
weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
|
||||
nn.Tanh(),
|
||||
]
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, c):
|
||||
return self.layers(c)
|
||||
|
||||
def inference(self, c):
|
||||
c = c.to(self.layers[1].weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.layers(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for _, layer in enumerate(self.layers):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
nn.utils.parametrize.remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
@@ -0,0 +1,50 @@
|
||||
from torch import nn
|
||||
|
||||
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
|
||||
|
||||
|
||||
class MelganMultiscaleDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
num_scales=3,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=16,
|
||||
max_channels=1024,
|
||||
downsample_factors=(4, 4, 4),
|
||||
pooling_kernel_size=4,
|
||||
pooling_stride=2,
|
||||
pooling_padding=2,
|
||||
groups_denominator=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
MelganDiscriminator(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_sizes=kernel_sizes,
|
||||
base_channels=base_channels,
|
||||
max_channels=max_channels,
|
||||
downsample_factors=downsample_factors,
|
||||
groups_denominator=groups_denominator,
|
||||
)
|
||||
for _ in range(num_scales)
|
||||
]
|
||||
)
|
||||
|
||||
self.pooling = nn.AvgPool1d(
|
||||
kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
scores = []
|
||||
feats = []
|
||||
for disc in self.discriminators:
|
||||
score, feat = disc(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
x = self.pooling(x)
|
||||
return scores, feats
|
||||
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.layers.pqmf import PQMF
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
||||
|
||||
class MultibandMelganGenerator(MelganGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=4,
|
||||
proj_kernel=7,
|
||||
base_channels=384,
|
||||
upsample_factors=(2, 8, 2, 2),
|
||||
res_kernel=3,
|
||||
num_res_blocks=3,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
proj_kernel=proj_kernel,
|
||||
base_channels=base_channels,
|
||||
upsample_factors=upsample_factors,
|
||||
res_kernel=res_kernel,
|
||||
num_res_blocks=num_res_blocks,
|
||||
)
|
||||
self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
|
||||
|
||||
def pqmf_analysis(self, x):
|
||||
return self.pqmf_layer.analysis(x)
|
||||
|
||||
def pqmf_synthesis(self, x):
|
||||
return self.pqmf_layer.synthesis(x)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, cond_features):
|
||||
cond_features = cond_features.to(self.layers[1].weight.device)
|
||||
cond_features = torch.nn.functional.pad(
|
||||
cond_features, (self.inference_padding, self.inference_padding), "replicate"
|
||||
)
|
||||
return self.pqmf_synthesis(self.layers(cond_features))
|
||||
@@ -0,0 +1,187 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||
|
||||
|
||||
class ParallelWaveganDiscriminator(nn.Module):
|
||||
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
|
||||
It classifies each audio window real/fake and returns a sequence
|
||||
of predictions.
|
||||
It is a stack of convolutional blocks with dilation.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=10,
|
||||
conv_channels=64,
|
||||
dilation_factor=1,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
|
||||
assert dilation_factor > 0, " [!] dilation factor must be > 0."
|
||||
self.conv_layers = nn.ModuleList()
|
||||
conv_in_channels = in_channels
|
||||
for i in range(num_layers - 1):
|
||||
if i == 0:
|
||||
dilation = 1
|
||||
else:
|
||||
dilation = i if dilation_factor == 1 else dilation_factor**i
|
||||
conv_in_channels = conv_channels
|
||||
padding = (kernel_size - 1) // 2 * dilation
|
||||
conv_layer = [
|
||||
nn.Conv1d(
|
||||
conv_in_channels,
|
||||
conv_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias,
|
||||
),
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
]
|
||||
self.conv_layers += conv_layer
|
||||
padding = (kernel_size - 1) // 2
|
||||
last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
|
||||
self.conv_layers += [last_conv_layer]
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x : (B, 1, T).
|
||||
Returns:
|
||||
Tensor: (B, 1, T)
|
||||
"""
|
||||
for f in self.conv_layers:
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
|
||||
class ResidualParallelWaveganDiscriminator(nn.Module):
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=30,
|
||||
stacks=3,
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
):
|
||||
super().__init__()
|
||||
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_layers = num_layers
|
||||
self.stacks = stacks
|
||||
self.kernel_size = kernel_size
|
||||
self.res_factor = math.sqrt(1.0 / num_layers)
|
||||
|
||||
# check the number of num_layers and stacks
|
||||
assert num_layers % stacks == 0
|
||||
layers_per_stack = num_layers // stacks
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = nn.Sequential(
|
||||
nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True),
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
)
|
||||
|
||||
# define residual blocks
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for layer in range(num_layers):
|
||||
dilation = 2 ** (layer % layers_per_stack)
|
||||
conv = ResidualBlock(
|
||||
kernel_size=kernel_size,
|
||||
res_channels=res_channels,
|
||||
gate_channels=gate_channels,
|
||||
skip_channels=skip_channels,
|
||||
aux_channels=-1,
|
||||
dilation=dilation,
|
||||
dropout=dropout,
|
||||
bias=bias,
|
||||
use_causal_conv=False,
|
||||
)
|
||||
self.conv_layers += [conv]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = nn.ModuleList(
|
||||
[
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True),
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: (B, 1, T).
|
||||
"""
|
||||
x = self.first_conv(x)
|
||||
|
||||
skips = 0
|
||||
for f in self.conv_layers:
|
||||
x, h = f(x, None)
|
||||
skips += h
|
||||
skips *= self.res_factor
|
||||
|
||||
# apply final layers
|
||||
x = skips
|
||||
for f in self.last_conv_layers:
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
print(f"Weight norm is removed from {m}.")
|
||||
remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
@@ -0,0 +1,164 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||
from TTS.vocoder.layers.upsample import ConvUpsample
|
||||
|
||||
|
||||
class ParallelWaveganGenerator(torch.nn.Module):
|
||||
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
|
||||
It is similar to WaveNet with no causal convolution.
|
||||
It is conditioned on an aux feature (spectrogram) to generate
|
||||
an output waveform from an input noise.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_res_blocks=30,
|
||||
stacks=3,
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
aux_channels=80,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
use_weight_norm=True,
|
||||
upsample_factors=[4, 4, 4, 4],
|
||||
inference_padding=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.aux_channels = aux_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.stacks = stacks
|
||||
self.kernel_size = kernel_size
|
||||
self.upsample_factors = upsample_factors
|
||||
self.upsample_scale = np.prod(upsample_factors)
|
||||
self.inference_padding = inference_padding
|
||||
self.use_weight_norm = use_weight_norm
|
||||
|
||||
# check the number of layers and stacks
|
||||
assert num_res_blocks % stacks == 0
|
||||
layers_per_stack = num_res_blocks // stacks
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True)
|
||||
|
||||
# define conv + upsampling network
|
||||
self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)
|
||||
|
||||
# define residual blocks
|
||||
self.conv_layers = torch.nn.ModuleList()
|
||||
for layer in range(num_res_blocks):
|
||||
dilation = 2 ** (layer % layers_per_stack)
|
||||
conv = ResidualBlock(
|
||||
kernel_size=kernel_size,
|
||||
res_channels=res_channels,
|
||||
gate_channels=gate_channels,
|
||||
skip_channels=skip_channels,
|
||||
aux_channels=aux_channels,
|
||||
dilation=dilation,
|
||||
dropout=dropout,
|
||||
bias=bias,
|
||||
)
|
||||
self.conv_layers += [conv]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
c: (B, C ,T').
|
||||
o: Output tensor (B, out_channels, T)
|
||||
"""
|
||||
# random noise
|
||||
x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
|
||||
# perform upsampling
|
||||
if c is not None and self.upsample_net is not None:
|
||||
c = self.upsample_net(c)
|
||||
assert (
|
||||
c.shape[-1] == x.shape[-1]
|
||||
), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
|
||||
|
||||
# encode to hidden representation
|
||||
x = self.first_conv(x)
|
||||
skips = 0
|
||||
for f in self.conv_layers:
|
||||
x, h = f(x, c)
|
||||
skips += h
|
||||
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
||||
|
||||
# apply final layers
|
||||
x = skips
|
||||
for f in self.last_conv_layers:
|
||||
x = f(x)
|
||||
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
c = c.to(self.first_conv.weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
# print(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
@@ -0,0 +1,203 @@
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GBlock(nn.Module):
|
||||
def __init__(self, in_channels, cond_channels, downsample_factor):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.downsample_factor = downsample_factor
|
||||
|
||||
self.start = nn.Sequential(
|
||||
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1),
|
||||
)
|
||||
self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1)
|
||||
self.end = nn.Sequential(
|
||||
nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2)
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.Conv1d(in_channels, in_channels * 2, kernel_size=1),
|
||||
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
|
||||
)
|
||||
|
||||
def forward(self, inputs, conditions):
|
||||
outputs = self.start(inputs) + self.lc_conv1d(conditions)
|
||||
outputs = self.end(outputs)
|
||||
residual_outputs = self.residual(inputs)
|
||||
outputs = outputs + residual_outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class DBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, downsample_factor):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.downsample_factor = downsample_factor
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor)
|
||||
self.layers = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.downsample_factor > 1:
|
||||
outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs))
|
||||
else:
|
||||
outputs = self.layers(inputs) + self.residual(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class ConditionalDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)):
|
||||
super().__init__()
|
||||
|
||||
assert len(downsample_factors) == len(out_channels) + 1
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.downsample_factors = downsample_factors
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.pre_cond_layers = nn.ModuleList()
|
||||
self.post_cond_layers = nn.ModuleList()
|
||||
|
||||
# layers before condition features
|
||||
self.pre_cond_layers += [DBlock(in_channels, 64, 1)]
|
||||
in_channels = 64
|
||||
for i, channel in enumerate(out_channels):
|
||||
self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i]))
|
||||
in_channels = channel
|
||||
|
||||
# condition block
|
||||
self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1])
|
||||
|
||||
# layers after condition block
|
||||
self.post_cond_layers += [
|
||||
DBlock(in_channels * 2, in_channels * 2, 1),
|
||||
DBlock(in_channels * 2, in_channels * 2, 1),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
nn.Conv1d(in_channels * 2, 1, kernel_size=1),
|
||||
]
|
||||
|
||||
def forward(self, inputs, conditions):
|
||||
batch_size = inputs.size()[0]
|
||||
outputs = inputs.view(batch_size, self.in_channels, -1)
|
||||
for layer in self.pre_cond_layers:
|
||||
outputs = layer(outputs)
|
||||
outputs = self.cond_block(outputs, conditions)
|
||||
for layer in self.post_cond_layers:
|
||||
outputs = layer(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class UnconditionalDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)):
|
||||
super().__init__()
|
||||
|
||||
self.downsample_factors = downsample_factors
|
||||
self.in_channels = in_channels
|
||||
self.downsample_factors = downsample_factors
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.layers += [DBlock(self.in_channels, base_channels, 1)]
|
||||
in_channels = base_channels
|
||||
for i, factor in enumerate(downsample_factors):
|
||||
self.layers.append(DBlock(in_channels, out_channels[i], factor))
|
||||
in_channels *= 2
|
||||
self.layers += [
|
||||
DBlock(in_channels, in_channels, 1),
|
||||
DBlock(in_channels, in_channels, 1),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
nn.Conv1d(in_channels, 1, kernel_size=1),
|
||||
]
|
||||
|
||||
def forward(self, inputs):
|
||||
batch_size = inputs.size()[0]
|
||||
outputs = inputs.view(batch_size, self.in_channels, -1)
|
||||
for layer in self.layers:
|
||||
outputs = layer(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class RandomWindowDiscriminator(nn.Module):
|
||||
"""Random Window Discriminator as described in
|
||||
http://arxiv.org/abs/1909.11646"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cond_channels,
|
||||
hop_length,
|
||||
uncond_disc_donwsample_factors=(8, 4),
|
||||
cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)),
|
||||
cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)),
|
||||
window_sizes=(512, 1024, 2048, 4096, 8192),
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_channels = cond_channels
|
||||
self.window_sizes = window_sizes
|
||||
self.hop_length = hop_length
|
||||
self.base_window_size = self.hop_length * 2
|
||||
self.ks = [ws // self.base_window_size for ws in window_sizes]
|
||||
|
||||
# check arguments
|
||||
assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes)
|
||||
for ws in window_sizes:
|
||||
assert ws % hop_length == 0
|
||||
|
||||
for idx, cf in enumerate(cond_disc_downsample_factors):
|
||||
assert np.prod(cf) == hop_length // self.ks[idx]
|
||||
|
||||
# define layers
|
||||
self.unconditional_discriminators = nn.ModuleList([])
|
||||
for k in self.ks:
|
||||
layer = UnconditionalDiscriminator(
|
||||
in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors
|
||||
)
|
||||
self.unconditional_discriminators.append(layer)
|
||||
|
||||
self.conditional_discriminators = nn.ModuleList([])
|
||||
for idx, k in enumerate(self.ks):
|
||||
layer = ConditionalDiscriminator(
|
||||
in_channels=k,
|
||||
cond_channels=cond_channels,
|
||||
downsample_factors=cond_disc_downsample_factors[idx],
|
||||
out_channels=cond_disc_out_channels[idx],
|
||||
)
|
||||
self.conditional_discriminators.append(layer)
|
||||
|
||||
def forward(self, x, c):
|
||||
scores = []
|
||||
feats = []
|
||||
# unconditional pass
|
||||
for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators):
|
||||
index = np.random.randint(x.shape[-1] - window_size)
|
||||
|
||||
score = layer(x[:, :, index : index + window_size])
|
||||
scores.append(score)
|
||||
|
||||
# conditional pass
|
||||
for window_size, layer in zip(self.window_sizes, self.conditional_discriminators):
|
||||
frame_size = window_size // self.hop_length
|
||||
lc_index = np.random.randint(c.shape[-1] - frame_size)
|
||||
sample_index = lc_index * self.hop_length
|
||||
x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length]
|
||||
c_sub = c[:, :, lc_index : lc_index + frame_size]
|
||||
|
||||
score = layer(x_sub, c_sub)
|
||||
scores.append(score)
|
||||
return scores, feats
|
||||
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class SpecDiscriminator(nn.Module):
|
||||
"""docstring for Discriminator."""
|
||||
|
||||
def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.fft_size = fft_size
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.stft = TorchSTFT(fft_size, hop_length, win_length)
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
|
||||
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
||||
|
||||
def forward(self, y):
|
||||
fmap = []
|
||||
with torch.no_grad():
|
||||
y = y.squeeze(1)
|
||||
y = self.stft(y)
|
||||
y = y.unsqueeze(1)
|
||||
for _, d in enumerate(self.discriminators):
|
||||
y = d(y)
|
||||
y = F.leaky_relu(y, LRELU_SLOPE)
|
||||
fmap.append(y)
|
||||
|
||||
y = self.out(y)
|
||||
fmap.append(y)
|
||||
|
||||
return torch.flatten(y, 1, -1), fmap
|
||||
|
||||
|
||||
class MultiResSpecDiscriminator(torch.nn.Module):
|
||||
def __init__( # pylint: disable=dangerous-default-value
|
||||
self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window"
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
|
||||
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
|
||||
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
scores = []
|
||||
feats = []
|
||||
for d in self.discriminators:
|
||||
score, feat = d(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
|
||||
return scores, feats
|
||||
|
||||
|
||||
class UnivnetDiscriminator(nn.Module):
|
||||
"""Univnet discriminator wrapping MPD and MSD."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
self.msd = MultiResSpecDiscriminator()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: discriminator scores.
|
||||
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||
"""
|
||||
scores, feats = self.mpd(x)
|
||||
scores_, feats_ = self.msd(x)
|
||||
return scores + scores_, feats + feats_
|
||||
@@ -0,0 +1,157 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
from TTS.vocoder.layers.lvc_block import LVCBlock
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class UnivnetGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int,
|
||||
cond_channels: int,
|
||||
upsample_factors: List[int],
|
||||
lvc_layers_each_block: int,
|
||||
lvc_kernel_size: int,
|
||||
kpnet_hidden_channels: int,
|
||||
kpnet_conv_size: int,
|
||||
dropout: float,
|
||||
use_weight_norm=True,
|
||||
):
|
||||
"""Univnet Generator network.
|
||||
|
||||
Paper: https://arxiv.org/pdf/2106.07889.pdf
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input tensor channels.
|
||||
out_channels (int): Number of channels of the output tensor.
|
||||
hidden_channels (int): Number of hidden network channels.
|
||||
cond_channels (int): Number of channels of the conditioning tensors.
|
||||
upsample_factors (List[int]): List of uplsample factors for the upsampling layers.
|
||||
lvc_layers_each_block (int): Number of LVC layers in each block.
|
||||
lvc_kernel_size (int): Kernel size of the LVC layers.
|
||||
kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
|
||||
kpnet_conv_size (int): Number of convolution channels in the key-point network.
|
||||
dropout (float): Dropout rate.
|
||||
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.upsample_scale = np.prod(upsample_factors)
|
||||
self.lvc_block_nums = len(upsample_factors)
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = torch.nn.Conv1d(
|
||||
in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
|
||||
)
|
||||
|
||||
# define residual blocks
|
||||
self.lvc_blocks = torch.nn.ModuleList()
|
||||
cond_hop_length = 1
|
||||
for n in range(self.lvc_block_nums):
|
||||
cond_hop_length = cond_hop_length * upsample_factors[n]
|
||||
lvcb = LVCBlock(
|
||||
in_channels=hidden_channels,
|
||||
cond_channels=cond_channels,
|
||||
upsample_ratio=upsample_factors[n],
|
||||
conv_layers=lvc_layers_each_block,
|
||||
conv_kernel_size=lvc_kernel_size,
|
||||
cond_hop_length=cond_hop_length,
|
||||
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
kpnet_dropout=dropout,
|
||||
)
|
||||
self.lvc_blocks += [lvcb]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Conv1d(
|
||||
hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, c):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
c (Tensor): Local conditioning auxiliary features (B, C ,T').
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T)
|
||||
"""
|
||||
# random noise
|
||||
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
x = self.first_conv(x)
|
||||
|
||||
for n in range(self.lvc_block_nums):
|
||||
x = self.lvc_blocks[n](x, c)
|
||||
|
||||
# apply final layers
|
||||
for f in self.last_conv_layers:
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = f(x)
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization module from all of the layers."""
|
||||
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
parametrize.remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
"""Apply weight normalization module from all of the layers."""
|
||||
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
# print(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
"""Return receptive field size."""
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""Perform inference.
|
||||
Args:
|
||||
c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
|
||||
Returns:
|
||||
Tensor: Output tensor (T, out_channels)
|
||||
"""
|
||||
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
|
||||
c = c.to(next(self.parameters()))
|
||||
return self.forward(c)
|
||||
@@ -0,0 +1,345 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets import WaveGradDataset
|
||||
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavegradArgs(Coqpit):
|
||||
in_channels: int = 80
|
||||
out_channels: int = 1
|
||||
use_weight_norm: bool = False
|
||||
y_conv_channels: int = 32
|
||||
x_conv_channels: int = 768
|
||||
dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
|
||||
ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
|
||||
upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
|
||||
upsample_dilations: List[List[int]] = field(
|
||||
default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
|
||||
)
|
||||
|
||||
|
||||
class Wavegrad(BaseVocoder):
|
||||
"""🐸 🌊 WaveGrad 🌊 model.
|
||||
Paper - https://arxiv.org/abs/2009.00713
|
||||
|
||||
Examples:
|
||||
Initializing the model.
|
||||
|
||||
>>> from TTS.vocoder.configs import WavegradConfig
|
||||
>>> config = WavegradConfig()
|
||||
>>> model = Wavegrad(config)
|
||||
|
||||
Paper Abstract:
|
||||
This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
|
||||
data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
|
||||
from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
|
||||
on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
|
||||
the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
|
||||
terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
|
||||
Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
|
||||
baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
|
||||
Audio samples are available at this https URL.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.use_weight_norm = config.model_params.use_weight_norm
|
||||
self.hop_len = np.prod(config.model_params.upsample_factors)
|
||||
self.noise_level = None
|
||||
self.num_steps = None
|
||||
self.beta = None
|
||||
self.alpha = None
|
||||
self.alpha_hat = None
|
||||
self.c1 = None
|
||||
self.c2 = None
|
||||
self.sigma = None
|
||||
|
||||
# dblocks
|
||||
self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
|
||||
self.dblocks = nn.ModuleList([])
|
||||
ic = config.model_params.y_conv_channels
|
||||
for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
|
||||
self.dblocks.append(DBlock(ic, oc, df))
|
||||
ic = oc
|
||||
|
||||
# film
|
||||
self.film = nn.ModuleList([])
|
||||
ic = config.model_params.y_conv_channels
|
||||
for oc in reversed(config.model_params.ublock_out_channels):
|
||||
self.film.append(FiLM(ic, oc))
|
||||
ic = oc
|
||||
|
||||
# ublocksn
|
||||
self.ublocks = nn.ModuleList([])
|
||||
ic = config.model_params.x_conv_channels
|
||||
for oc, uf, ud in zip(
|
||||
config.model_params.ublock_out_channels,
|
||||
config.model_params.upsample_factors,
|
||||
config.model_params.upsample_dilations,
|
||||
):
|
||||
self.ublocks.append(UBlock(ic, oc, uf, ud))
|
||||
ic = oc
|
||||
|
||||
self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
|
||||
self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)
|
||||
|
||||
if config.model_params.use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, x, spectrogram, noise_scale):
|
||||
shift_and_scale = []
|
||||
|
||||
x = self.y_conv(x)
|
||||
shift_and_scale.append(self.film[0](x, noise_scale))
|
||||
|
||||
for film, layer in zip(self.film[1:], self.dblocks):
|
||||
x = layer(x)
|
||||
shift_and_scale.append(film(x, noise_scale))
|
||||
|
||||
x = self.x_conv(spectrogram)
|
||||
for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)):
|
||||
x = layer(x, film_shift, film_scale)
|
||||
x = self.out_conv(x)
|
||||
return x
|
||||
|
||||
def load_noise_schedule(self, path):
|
||||
beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg
|
||||
self.compute_noise_level(beta)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, y_n=None):
|
||||
"""
|
||||
Shapes:
|
||||
x: :math:`[B, C , T]`
|
||||
y_n: :math:`[B, 1, T]`
|
||||
"""
|
||||
if y_n is None:
|
||||
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
|
||||
else:
|
||||
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
|
||||
y_n = y_n.type_as(x)
|
||||
sqrt_alpha_hat = self.noise_level.to(x)
|
||||
for n in range(len(self.alpha) - 1, -1, -1):
|
||||
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
|
||||
if n > 0:
|
||||
z = torch.randn_like(y_n)
|
||||
y_n += self.sigma[n - 1] * z
|
||||
y_n.clamp_(-1.0, 1.0)
|
||||
return y_n
|
||||
|
||||
def compute_y_n(self, y_0):
|
||||
"""Compute noisy audio based on noise schedule"""
|
||||
self.noise_level = self.noise_level.to(y_0)
|
||||
if len(y_0.shape) == 3:
|
||||
y_0 = y_0.squeeze(1)
|
||||
s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
|
||||
l_a, l_b = self.noise_level[s], self.noise_level[s + 1]
|
||||
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
|
||||
noise_scale = noise_scale.unsqueeze(1)
|
||||
noise = torch.randn_like(y_0)
|
||||
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
|
||||
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
|
||||
|
||||
def compute_noise_level(self, beta):
|
||||
"""Compute noise schedule parameters"""
|
||||
self.num_steps = len(beta)
|
||||
alpha = 1 - beta
|
||||
alpha_hat = np.cumprod(alpha)
|
||||
noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0)
|
||||
noise_level = alpha_hat**0.5
|
||||
|
||||
# pylint: disable=not-callable
|
||||
self.beta = torch.tensor(beta.astype(np.float32))
|
||||
self.alpha = torch.tensor(alpha.astype(np.float32))
|
||||
self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
|
||||
self.noise_level = torch.tensor(noise_level.astype(np.float32))
|
||||
|
||||
self.c1 = 1 / self.alpha**0.5
|
||||
self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
|
||||
self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for _, layer in enumerate(self.dblocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.film):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.ublocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
remove_parametrizations(self.x_conv, "weight")
|
||||
remove_parametrizations(self.out_conv, "weight")
|
||||
remove_parametrizations(self.y_conv, "weight")
|
||||
|
||||
def apply_weight_norm(self):
|
||||
for _, layer in enumerate(self.dblocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.film):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.ublocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
self.x_conv = weight_norm(self.x_conv)
|
||||
self.out_conv = weight_norm(self.out_conv)
|
||||
self.y_conv = weight_norm(self.y_conv)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.config.model_params.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
betas = np.linspace(
|
||||
config["test_noise_schedule"]["min_val"],
|
||||
config["test_noise_schedule"]["max_val"],
|
||||
config["test_noise_schedule"]["num_steps"],
|
||||
)
|
||||
self.compute_noise_level(betas)
|
||||
else:
|
||||
betas = np.linspace(
|
||||
config["train_noise_schedule"]["min_val"],
|
||||
config["train_noise_schedule"]["max_val"],
|
||||
config["train_noise_schedule"]["num_steps"],
|
||||
)
|
||||
self.compute_noise_level(betas)
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
# format data
|
||||
x = batch["input"]
|
||||
y = batch["waveform"]
|
||||
|
||||
# set noise scale
|
||||
noise, x_noisy, noise_scale = self.compute_y_n(y)
|
||||
|
||||
# forward pass
|
||||
noise_hat = self.forward(x_noisy, x, noise_scale)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(noise, noise_hat)
|
||||
return {"model_output": noise_hat}, {"loss": loss}
|
||||
|
||||
def train_log( # pylint: disable=no-self-use
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log( # pylint: disable=no-self-use
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
|
||||
# setup noise schedule and inference
|
||||
ap = assets["audio_processor"]
|
||||
noise_schedule = self.config["test_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for sample in samples:
|
||||
x = sample[0]
|
||||
x = x[None, :, :].to(next(self.parameters()).device)
|
||||
y = sample[1]
|
||||
y = y[None, :]
|
||||
# compute voice
|
||||
y_pred = self.inference(x)
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_pred, y, ap, "test")
|
||||
# Sample audio
|
||||
sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
|
||||
return figures, {"test/audio": sample_voice}
|
||||
|
||||
def get_optimizer(self):
|
||||
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
|
||||
|
||||
def get_scheduler(self, optimizer):
|
||||
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
|
||||
|
||||
@staticmethod
|
||||
def get_criterion():
|
||||
return torch.nn.L1Loss()
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: Dict) -> Dict:
|
||||
# return a whole audio segment
|
||||
m, y = batch[0], batch[1]
|
||||
y = y.unsqueeze(1)
|
||||
return {"input": m, "waveform": y}
|
||||
|
||||
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
|
||||
ap = assets["audio_processor"]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=samples,
|
||||
seq_len=self.config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=self.config.pad_short,
|
||||
conv_pad=self.config.conv_pad,
|
||||
is_training=not is_eval,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=config.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
shuffle=num_gpus <= 1,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def on_epoch_start(self, trainer): # pylint: disable=unused-argument
|
||||
noise_schedule = self.config["train_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavegradConfig"):
|
||||
return Wavegrad(config)
|
||||
@@ -0,0 +1,646 @@
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_decode
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
from TTS.vocoder.layers.losses import WaveRNNLoss
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
|
||||
|
||||
|
||||
def stream(string, variables):
|
||||
sys.stdout.write(f"\r{string}" % variables)
|
||||
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
# relates https://github.com/pytorch/pytorch/issues/42305
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
||||
self.batch_norm1 = nn.BatchNorm1d(dims)
|
||||
self.batch_norm2 = nn.BatchNorm1d(dims)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.batch_norm1(x)
|
||||
x = F.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.batch_norm2(x)
|
||||
return x + residual
|
||||
|
||||
|
||||
class MelResNet(nn.Module):
|
||||
def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
|
||||
super().__init__()
|
||||
k_size = pad * 2 + 1
|
||||
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
|
||||
self.batch_norm = nn.BatchNorm1d(compute_dims)
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_res_blocks):
|
||||
self.layers.append(ResBlock(compute_dims))
|
||||
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.batch_norm(x)
|
||||
x = F.relu(x)
|
||||
for f in self.layers:
|
||||
x = f(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Stretch2d(nn.Module):
|
||||
def __init__(self, x_scale, y_scale):
|
||||
super().__init__()
|
||||
self.x_scale = x_scale
|
||||
self.y_scale = y_scale
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
x = x.unsqueeze(-1).unsqueeze(3)
|
||||
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
|
||||
return x.view(b, c, h * self.y_scale, w * self.x_scale)
|
||||
|
||||
|
||||
class UpsampleNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feat_dims,
|
||||
upsample_scales,
|
||||
compute_dims,
|
||||
num_res_blocks,
|
||||
res_out_dims,
|
||||
pad,
|
||||
use_aux_net,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_scale = np.cumproduct(upsample_scales)[-1]
|
||||
self.indent = pad * self.total_scale
|
||||
self.use_aux_net = use_aux_net
|
||||
if use_aux_net:
|
||||
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
||||
self.resnet_stretch = Stretch2d(self.total_scale, 1)
|
||||
self.up_layers = nn.ModuleList()
|
||||
for scale in upsample_scales:
|
||||
k_size = (1, scale * 2 + 1)
|
||||
padding = (0, scale)
|
||||
stretch = Stretch2d(scale, 1)
|
||||
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
|
||||
conv.weight.data.fill_(1.0 / k_size[1])
|
||||
self.up_layers.append(stretch)
|
||||
self.up_layers.append(conv)
|
||||
|
||||
def forward(self, m):
|
||||
if self.use_aux_net:
|
||||
aux = self.resnet(m).unsqueeze(1)
|
||||
aux = self.resnet_stretch(aux)
|
||||
aux = aux.squeeze(1)
|
||||
aux = aux.transpose(1, 2)
|
||||
else:
|
||||
aux = None
|
||||
m = m.unsqueeze(1)
|
||||
for f in self.up_layers:
|
||||
m = f(m)
|
||||
m = m.squeeze(1)[:, :, self.indent : -self.indent]
|
||||
return m.transpose(1, 2), aux
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.pad = pad
|
||||
self.indent = pad * scale
|
||||
self.use_aux_net = use_aux_net
|
||||
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
||||
|
||||
def forward(self, m):
|
||||
if self.use_aux_net:
|
||||
aux = self.resnet(m)
|
||||
aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True)
|
||||
aux = aux.transpose(1, 2)
|
||||
else:
|
||||
aux = None
|
||||
m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True)
|
||||
m = m[:, :, self.indent : -self.indent]
|
||||
m = m * 0.045 # empirically found
|
||||
|
||||
return m.transpose(1, 2), aux
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavernnArgs(Coqpit):
|
||||
"""🐸 WaveRNN model arguments.
|
||||
|
||||
rnn_dims (int):
|
||||
Number of hidden channels in RNN layers. Defaults to 512.
|
||||
fc_dims (int):
|
||||
Number of hidden channels in fully-conntected layers. Defaults to 512.
|
||||
compute_dims (int):
|
||||
Number of hidden channels in the feature ResNet. Defaults to 128.
|
||||
res_out_dim (int):
|
||||
Number of hidden channels in the feature ResNet output. Defaults to 128.
|
||||
num_res_blocks (int):
|
||||
Number of residual blocks in the ResNet. Defaults to 10.
|
||||
use_aux_net (bool):
|
||||
enable/disable the feature ResNet. Defaults to True.
|
||||
use_upsample_net (bool):
|
||||
enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True.
|
||||
upsample_factors (list):
|
||||
Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```.
|
||||
mode (str):
|
||||
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
|
||||
Gaussian Distribution and `bits` for quantized bits as the model's output.
|
||||
mulaw (bool):
|
||||
enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
|
||||
to `True`.
|
||||
pad (int):
|
||||
Padding applied to the input feature frames against the convolution layers of the feature network.
|
||||
Defaults to 2.
|
||||
"""
|
||||
|
||||
rnn_dims: int = 512
|
||||
fc_dims: int = 512
|
||||
compute_dims: int = 128
|
||||
res_out_dims: int = 128
|
||||
num_res_blocks: int = 10
|
||||
use_aux_net: bool = True
|
||||
use_upsample_net: bool = True
|
||||
upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
|
||||
mode: str = "mold" # mold [string], gauss [string], bits [int]
|
||||
mulaw: bool = True # apply mulaw if mode is bits
|
||||
pad: int = 2
|
||||
feat_dims: int = 80
|
||||
|
||||
|
||||
class Wavernn(BaseVocoder):
|
||||
def __init__(self, config: Coqpit):
|
||||
"""🐸 WaveRNN model.
|
||||
Original paper - https://arxiv.org/abs/1802.08435
|
||||
Official implementation - https://github.com/fatchord/WaveRNN
|
||||
|
||||
Args:
|
||||
config (Coqpit): [description]
|
||||
|
||||
Raises:
|
||||
RuntimeError: [description]
|
||||
|
||||
Examples:
|
||||
>>> from TTS.vocoder.configs import WavernnConfig
|
||||
>>> config = WavernnConfig()
|
||||
>>> model = Wavernn(config)
|
||||
|
||||
Paper Abstract:
|
||||
Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
|
||||
both estimating the data distribution and generating high-quality samples. Efficient sampling for this
|
||||
class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
|
||||
describe a set of general techniques for reducing sampling time while maintaining high output quality.
|
||||
We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
|
||||
matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
|
||||
possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
|
||||
pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
|
||||
parameters, large sparse networks perform better than small dense networks and this relationship holds for
|
||||
sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
|
||||
high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
|
||||
subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
|
||||
samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
|
||||
orthogonal method for increasing sampling efficiency.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
if isinstance(self.args.mode, int):
|
||||
self.n_classes = 2**self.args.mode
|
||||
elif self.args.mode == "mold":
|
||||
self.n_classes = 3 * 10
|
||||
elif self.args.mode == "gauss":
|
||||
self.n_classes = 2
|
||||
else:
|
||||
raise RuntimeError("Unknown model mode value - ", self.args.mode)
|
||||
|
||||
self.ap = AudioProcessor(**config.audio.to_dict())
|
||||
self.aux_dims = self.args.res_out_dims // 4
|
||||
|
||||
if self.args.use_upsample_net:
|
||||
assert (
|
||||
np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length
|
||||
), " [!] upsample scales needs to be equal to hop_length"
|
||||
self.upsample = UpsampleNetwork(
|
||||
self.args.feat_dims,
|
||||
self.args.upsample_factors,
|
||||
self.args.compute_dims,
|
||||
self.args.num_res_blocks,
|
||||
self.args.res_out_dims,
|
||||
self.args.pad,
|
||||
self.args.use_aux_net,
|
||||
)
|
||||
else:
|
||||
self.upsample = Upsample(
|
||||
config.audio.hop_length,
|
||||
self.args.pad,
|
||||
self.args.num_res_blocks,
|
||||
self.args.feat_dims,
|
||||
self.args.compute_dims,
|
||||
self.args.res_out_dims,
|
||||
self.args.use_aux_net,
|
||||
)
|
||||
if self.args.use_aux_net:
|
||||
self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
|
||||
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
|
||||
self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
|
||||
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
|
||||
else:
|
||||
self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
|
||||
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
|
||||
self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
|
||||
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
|
||||
|
||||
def forward(self, x, mels):
|
||||
bsize = x.size(0)
|
||||
h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
|
||||
h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
|
||||
mels, aux = self.upsample(mels)
|
||||
|
||||
if self.args.use_aux_net:
|
||||
aux_idx = [self.aux_dims * i for i in range(5)]
|
||||
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
|
||||
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
|
||||
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
|
||||
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
|
||||
|
||||
x = (
|
||||
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
|
||||
if self.args.use_aux_net
|
||||
else torch.cat([x.unsqueeze(-1), mels], dim=2)
|
||||
)
|
||||
x = self.I(x)
|
||||
res = x
|
||||
self.rnn1.flatten_parameters()
|
||||
x, _ = self.rnn1(x, h1)
|
||||
|
||||
x = x + res
|
||||
res = x
|
||||
x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
|
||||
self.rnn2.flatten_parameters()
|
||||
x, _ = self.rnn2(x, h2)
|
||||
|
||||
x = x + res
|
||||
x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc1(x))
|
||||
|
||||
x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc2(x))
|
||||
return self.fc3(x)
|
||||
|
||||
def inference(self, mels, batched=None, target=None, overlap=None):
|
||||
self.eval()
|
||||
output = []
|
||||
start = time.time()
|
||||
rnn1 = self.get_gru_cell(self.rnn1)
|
||||
rnn2 = self.get_gru_cell(self.rnn2)
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(mels, np.ndarray):
|
||||
mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
|
||||
|
||||
if mels.ndim == 2:
|
||||
mels = mels.unsqueeze(0)
|
||||
wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length
|
||||
|
||||
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
|
||||
mels, aux = self.upsample(mels.transpose(1, 2))
|
||||
|
||||
if batched:
|
||||
mels = self.fold_with_overlap(mels, target, overlap)
|
||||
if aux is not None:
|
||||
aux = self.fold_with_overlap(aux, target, overlap)
|
||||
|
||||
b_size, seq_len, _ = mels.size()
|
||||
|
||||
h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
|
||||
h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
|
||||
x = torch.zeros(b_size, 1).type_as(mels)
|
||||
|
||||
if self.args.use_aux_net:
|
||||
d = self.aux_dims
|
||||
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
|
||||
|
||||
for i in range(seq_len):
|
||||
m_t = mels[:, i, :]
|
||||
|
||||
if self.args.use_aux_net:
|
||||
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
|
||||
|
||||
x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
|
||||
x = self.I(x)
|
||||
h1 = rnn1(x, h1)
|
||||
|
||||
x = x + h1
|
||||
inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
|
||||
h2 = rnn2(inp, h2)
|
||||
|
||||
x = x + h2
|
||||
x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc1(x))
|
||||
|
||||
x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc2(x))
|
||||
|
||||
logits = self.fc3(x)
|
||||
|
||||
if self.args.mode == "mold":
|
||||
sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).type_as(mels)
|
||||
elif self.args.mode == "gauss":
|
||||
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).type_as(mels)
|
||||
elif isinstance(self.args.mode, int):
|
||||
posterior = F.softmax(logits, dim=1)
|
||||
distrib = torch.distributions.Categorical(posterior)
|
||||
|
||||
sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
|
||||
output.append(sample)
|
||||
x = sample.unsqueeze(-1)
|
||||
else:
|
||||
raise RuntimeError("Unknown model mode value - ", self.args.mode)
|
||||
|
||||
if i % 100 == 0:
|
||||
self.gen_display(i, seq_len, b_size, start)
|
||||
|
||||
output = torch.stack(output).transpose(0, 1)
|
||||
output = output.cpu()
|
||||
if batched:
|
||||
output = output.numpy()
|
||||
output = output.astype(np.float64)
|
||||
|
||||
output = self.xfade_and_unfold(output, target, overlap)
|
||||
else:
|
||||
output = output[0]
|
||||
|
||||
if self.args.mulaw and isinstance(self.args.mode, int):
|
||||
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
|
||||
|
||||
# Fade-out at the end to avoid signal cutting out suddenly
|
||||
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
|
||||
output = output[:wave_len]
|
||||
|
||||
if wave_len > len(fade_out):
|
||||
output[-20 * self.config.audio.hop_length :] *= fade_out
|
||||
|
||||
self.train()
|
||||
return output
|
||||
|
||||
def gen_display(self, i, seq_len, b_size, start):
|
||||
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
|
||||
realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
|
||||
stream(
|
||||
"%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
|
||||
(i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
|
||||
)
|
||||
|
||||
def fold_with_overlap(self, x, target, overlap):
|
||||
"""Fold the tensor with overlap for quick batched inference.
|
||||
Overlap will be used for crossfading in xfade_and_unfold()
|
||||
Args:
|
||||
x (tensor) : Upsampled conditioning features.
|
||||
shape=(1, timesteps, features)
|
||||
target (int) : Target timesteps for each index of batch
|
||||
overlap (int) : Timesteps for both xfade and rnn warmup
|
||||
Return:
|
||||
(tensor) : shape=(num_folds, target + 2 * overlap, features)
|
||||
Details:
|
||||
x = [[h1, h2, ... hn]]
|
||||
Where each h is a vector of conditioning features
|
||||
Eg: target=2, overlap=1 with x.size(1)=10
|
||||
folded = [[h1, h2, h3, h4],
|
||||
[h4, h5, h6, h7],
|
||||
[h7, h8, h9, h10]]
|
||||
"""
|
||||
|
||||
_, total_len, features = x.size()
|
||||
|
||||
# Calculate variables needed
|
||||
num_folds = (total_len - overlap) // (target + overlap)
|
||||
extended_len = num_folds * (overlap + target) + overlap
|
||||
remaining = total_len - extended_len
|
||||
|
||||
# Pad if some time steps poking out
|
||||
if remaining != 0:
|
||||
num_folds += 1
|
||||
padding = target + 2 * overlap - remaining
|
||||
x = self.pad_tensor(x, padding, side="after")
|
||||
|
||||
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
|
||||
|
||||
# Get the values for the folded tensor
|
||||
for i in range(num_folds):
|
||||
start = i * (target + overlap)
|
||||
end = start + target + 2 * overlap
|
||||
folded[i] = x[:, start:end, :]
|
||||
|
||||
return folded
|
||||
|
||||
@staticmethod
|
||||
def get_gru_cell(gru):
|
||||
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
|
||||
gru_cell.weight_hh.data = gru.weight_hh_l0.data
|
||||
gru_cell.weight_ih.data = gru.weight_ih_l0.data
|
||||
gru_cell.bias_hh.data = gru.bias_hh_l0.data
|
||||
gru_cell.bias_ih.data = gru.bias_ih_l0.data
|
||||
return gru_cell
|
||||
|
||||
@staticmethod
|
||||
def pad_tensor(x, pad, side="both"):
|
||||
# NB - this is just a quick method i need right now
|
||||
# i.e., it won't generalise to other shapes/dims
|
||||
b, t, c = x.size()
|
||||
total = t + 2 * pad if side == "both" else t + pad
|
||||
padded = torch.zeros(b, total, c).to(x.device)
|
||||
if side in ("before", "both"):
|
||||
padded[:, pad : pad + t, :] = x
|
||||
elif side == "after":
|
||||
padded[:, :t, :] = x
|
||||
return padded
|
||||
|
||||
@staticmethod
|
||||
def xfade_and_unfold(y, target, overlap):
|
||||
"""Applies a crossfade and unfolds into a 1d array.
|
||||
Args:
|
||||
y (ndarry) : Batched sequences of audio samples
|
||||
shape=(num_folds, target + 2 * overlap)
|
||||
dtype=np.float64
|
||||
overlap (int) : Timesteps for both xfade and rnn warmup
|
||||
Return:
|
||||
(ndarry) : audio samples in a 1d array
|
||||
shape=(total_len)
|
||||
dtype=np.float64
|
||||
Details:
|
||||
y = [[seq1],
|
||||
[seq2],
|
||||
[seq3]]
|
||||
Apply a gain envelope at both ends of the sequences
|
||||
y = [[seq1_in, seq1_target, seq1_out],
|
||||
[seq2_in, seq2_target, seq2_out],
|
||||
[seq3_in, seq3_target, seq3_out]]
|
||||
Stagger and add up the groups of samples:
|
||||
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
|
||||
"""
|
||||
|
||||
num_folds, length = y.shape
|
||||
target = length - 2 * overlap
|
||||
total_len = num_folds * (target + overlap) + overlap
|
||||
|
||||
# Need some silence for the rnn warmup
|
||||
silence_len = overlap // 2
|
||||
fade_len = overlap - silence_len
|
||||
silence = np.zeros((silence_len), dtype=np.float64)
|
||||
|
||||
# Equal power crossfade
|
||||
t = np.linspace(-1, 1, fade_len, dtype=np.float64)
|
||||
fade_in = np.sqrt(0.5 * (1 + t))
|
||||
fade_out = np.sqrt(0.5 * (1 - t))
|
||||
|
||||
# Concat the silence to the fades
|
||||
fade_in = np.concatenate([silence, fade_in])
|
||||
fade_out = np.concatenate([fade_out, silence])
|
||||
|
||||
# Apply the gain to the overlap samples
|
||||
y[:, :overlap] *= fade_in
|
||||
y[:, -overlap:] *= fade_out
|
||||
|
||||
unfolded = np.zeros((total_len), dtype=np.float64)
|
||||
|
||||
# Loop to add up all the samples
|
||||
for i in range(num_folds):
|
||||
start = i * (target + overlap)
|
||||
end = start + target + 2 * overlap
|
||||
unfolded[start:end] += y[i]
|
||||
|
||||
return unfolded
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
mels = batch["input"]
|
||||
waveform = batch["waveform"]
|
||||
waveform_coarse = batch["waveform_coarse"]
|
||||
|
||||
y_hat = self.forward(waveform, mels)
|
||||
if isinstance(self.args.mode, int):
|
||||
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
waveform_coarse = waveform_coarse.float()
|
||||
waveform_coarse = waveform_coarse.unsqueeze(-1)
|
||||
# compute losses
|
||||
loss_dict = criterion(y_hat, waveform_coarse)
|
||||
return {"model_output": y_hat}, loss_dict
|
||||
|
||||
def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
@torch.no_grad()
|
||||
def test(
|
||||
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, Dict]:
|
||||
ap = self.ap
|
||||
figures = {}
|
||||
audios = {}
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for idx, sample in enumerate(samples):
|
||||
x = torch.FloatTensor(sample[0])
|
||||
x = x.to(next(self.parameters()).device)
|
||||
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
|
||||
x_hat = ap.melspectrogram(y_hat)
|
||||
figures.update(
|
||||
{
|
||||
f"test_{idx}/ground_truth": plot_spectrogram(x.T),
|
||||
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
|
||||
}
|
||||
)
|
||||
audios.update({f"test_{idx}/audio": y_hat})
|
||||
# audios.update({f"real_{idx}/audio": y_hat})
|
||||
return figures, audios
|
||||
|
||||
def test_log(
|
||||
self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
figures, audios = outputs
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: Dict) -> Dict:
|
||||
waveform = batch[0]
|
||||
mels = batch[1]
|
||||
waveform_coarse = batch[2]
|
||||
return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}
|
||||
|
||||
def get_data_loader( # pylint: disable=no-self-use
|
||||
self,
|
||||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: True,
|
||||
samples: List,
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
):
|
||||
ap = self.ap
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=samples,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=config.model_args.pad,
|
||||
mode=config.model_args.mode,
|
||||
mulaw=config.model_args.mulaw,
|
||||
is_training=not is_eval,
|
||||
verbose=verbose,
|
||||
)
|
||||
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1 if is_eval else config.batch_size,
|
||||
shuffle=num_gpus == 0,
|
||||
collate_fn=dataset.collate,
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
return loader
|
||||
|
||||
def get_criterion(self):
|
||||
# define train functions
|
||||
return WaveRNNLoss(self.args.mode)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavernnConfig"):
|
||||
return Wavernn(config)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,154 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
|
||||
def gaussian_loss(y_hat, y, log_std_min=-7.0):
|
||||
assert y_hat.dim() == 3
|
||||
assert y_hat.size(2) == 2
|
||||
mean = y_hat[:, :, :1]
|
||||
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
|
||||
# TODO: replace with pytorch dist
|
||||
log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
|
||||
return log_probs.squeeze().mean()
|
||||
|
||||
|
||||
def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
|
||||
assert y_hat.size(2) == 2
|
||||
mean = y_hat[:, :, :1]
|
||||
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
|
||||
dist = Normal(
|
||||
mean,
|
||||
torch.exp(log_std),
|
||||
)
|
||||
sample = dist.sample()
|
||||
sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
|
||||
del dist
|
||||
return sample
|
||||
|
||||
|
||||
def log_sum_exp(x):
|
||||
"""numerically stable log_sum_exp implementation that prevents overflow"""
|
||||
# TF ordering
|
||||
axis = len(x.size()) - 1
|
||||
m, _ = torch.max(x, dim=axis)
|
||||
m2, _ = torch.max(x, dim=axis, keepdim=True)
|
||||
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
|
||||
|
||||
|
||||
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
|
||||
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True):
|
||||
if log_scale_min is None:
|
||||
log_scale_min = float(np.log(1e-14))
|
||||
y_hat = y_hat.permute(0, 2, 1)
|
||||
assert y_hat.dim() == 3
|
||||
assert y_hat.size(1) % 3 == 0
|
||||
nr_mix = y_hat.size(1) // 3
|
||||
|
||||
# (B x T x C)
|
||||
y_hat = y_hat.transpose(1, 2)
|
||||
|
||||
# unpack parameters. (B, T, num_mixtures) x 3
|
||||
logit_probs = y_hat[:, :, :nr_mix]
|
||||
means = y_hat[:, :, nr_mix : 2 * nr_mix]
|
||||
log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)
|
||||
|
||||
# B x T x 1 -> B x T x num_mixtures
|
||||
y = y.expand_as(means)
|
||||
|
||||
centered_y = y - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
|
||||
cdf_plus = torch.sigmoid(plus_in)
|
||||
min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
|
||||
cdf_min = torch.sigmoid(min_in)
|
||||
|
||||
# log probability for edge case of 0 (before scaling)
|
||||
# equivalent: torch.log(F.sigmoid(plus_in))
|
||||
log_cdf_plus = plus_in - F.softplus(plus_in)
|
||||
|
||||
# log probability for edge case of 255 (before scaling)
|
||||
# equivalent: (1 - F.sigmoid(min_in)).log()
|
||||
log_one_minus_cdf_min = -F.softplus(min_in)
|
||||
|
||||
# probability for all other cases
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
mid_in = inv_stdv * centered_y
|
||||
# log probability in the center of the bin, to be used in extreme cases
|
||||
# (not actually used in our code)
|
||||
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
|
||||
|
||||
# tf equivalent
|
||||
|
||||
# log_probs = tf.where(x < -0.999, log_cdf_plus,
|
||||
# tf.where(x > 0.999, log_one_minus_cdf_min,
|
||||
# tf.where(cdf_delta > 1e-5,
|
||||
# tf.log(tf.maximum(cdf_delta, 1e-12)),
|
||||
# log_pdf_mid - np.log(127.5))))
|
||||
|
||||
# TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
|
||||
# for num_classes=65536 case? 1e-7? not sure..
|
||||
inner_inner_cond = (cdf_delta > 1e-5).float()
|
||||
|
||||
inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
|
||||
log_pdf_mid - np.log((num_classes - 1) / 2)
|
||||
)
|
||||
inner_cond = (y > 0.999).float()
|
||||
inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
|
||||
cond = (y < -0.999).float()
|
||||
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
|
||||
|
||||
log_probs = log_probs + F.log_softmax(logit_probs, -1)
|
||||
|
||||
if reduce:
|
||||
return -torch.mean(log_sum_exp(log_probs))
|
||||
return -log_sum_exp(log_probs).unsqueeze(-1)
|
||||
|
||||
|
||||
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
|
||||
"""
|
||||
Sample from discretized mixture of logistic distributions
|
||||
Args:
|
||||
y (Tensor): :math:`[B, C, T]`
|
||||
log_scale_min (float): Log scale minimum value
|
||||
Returns:
|
||||
Tensor: sample in range of [-1, 1].
|
||||
"""
|
||||
if log_scale_min is None:
|
||||
log_scale_min = float(np.log(1e-14))
|
||||
assert y.size(1) % 3 == 0
|
||||
nr_mix = y.size(1) // 3
|
||||
|
||||
# B x T x C
|
||||
y = y.transpose(1, 2)
|
||||
logit_probs = y[:, :, :nr_mix]
|
||||
|
||||
# sample mixture indicator from softmax
|
||||
temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
|
||||
temp = logit_probs.data - torch.log(-torch.log(temp))
|
||||
_, argmax = temp.max(dim=-1)
|
||||
|
||||
# (B, T) -> (B, T, nr_mix)
|
||||
one_hot = to_one_hot(argmax, nr_mix)
|
||||
# select logistic parameters
|
||||
means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
|
||||
log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
|
||||
# sample from logistic & clip to interval
|
||||
# we don't actually round to the nearest 8bit value when sampling
|
||||
u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
|
||||
x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))
|
||||
|
||||
x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def to_one_hot(tensor, n, fill_with=1.0):
|
||||
# we perform one hot encore with respect to the last axis
|
||||
one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor)
|
||||
one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
|
||||
return one_hot
|
||||
@@ -0,0 +1,72 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
def interpolate_vocoder_input(scale_factor, spec):
|
||||
"""Interpolate spectrogram by the scale factor.
|
||||
It is mainly used to match the sampling rates of
|
||||
the tts and vocoder models.
|
||||
|
||||
Args:
|
||||
scale_factor (float): scale factor to interpolate the spectrogram
|
||||
spec (np.array): spectrogram to be interpolated
|
||||
|
||||
Returns:
|
||||
torch.tensor: interpolated spectrogram.
|
||||
"""
|
||||
print(" > before interpolation :", spec.shape)
|
||||
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
|
||||
spec = torch.nn.functional.interpolate(
|
||||
spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
print(" > after interpolation :", spec.shape)
|
||||
return spec
|
||||
|
||||
|
||||
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
|
||||
"""Plot the predicted and the real waveform and their spectrograms.
|
||||
|
||||
Args:
|
||||
y_hat (torch.tensor): Predicted waveform.
|
||||
y (torch.tensor): Real waveform.
|
||||
ap (AudioProcessor): Audio processor used to process the waveform.
|
||||
name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict: output figures keyed by the name of the figures.
|
||||
""" """Plot vocoder model results"""
|
||||
if name_prefix is None:
|
||||
name_prefix = ""
|
||||
|
||||
# select an instance from batch
|
||||
y_hat = y_hat[0].squeeze().detach().cpu().numpy()
|
||||
y = y[0].squeeze().detach().cpu().numpy()
|
||||
|
||||
spec_fake = ap.melspectrogram(y_hat).T
|
||||
spec_real = ap.melspectrogram(y).T
|
||||
spec_diff = np.abs(spec_fake - spec_real)
|
||||
|
||||
# plot figure and save it
|
||||
fig_wave = plt.figure()
|
||||
plt.subplot(2, 1, 1)
|
||||
plt.plot(y)
|
||||
plt.title("groundtruth speech")
|
||||
plt.subplot(2, 1, 2)
|
||||
plt.plot(y_hat)
|
||||
plt.title("generated speech")
|
||||
plt.tight_layout()
|
||||
plt.close()
|
||||
|
||||
figures = {
|
||||
name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
|
||||
name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
|
||||
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
|
||||
name_prefix + "speech_comparison": fig_wave,
|
||||
}
|
||||
return figures
|
||||
Reference in New Issue
Block a user