Add files via upload
Binary file not shown.
@@ -0,0 +1,278 @@
from dataclasses import dataclass, field
from typing import List, Optional

from coqpit import Coqpit

from TTS.vc.configs.shared_configs import BaseVCConfig


@dataclass
class FreeVCAudioConfig(Coqpit):
    """Audio configuration

    Args:
        max_wav_value (float):
            The maximum value of the waveform.

        input_sample_rate (int):
            The sampling rate of the input waveform.

        output_sample_rate (int):
            The sampling rate of the output waveform.

        filter_length (int):
            The length of the FFT filter, i.e. the FFT window size.

        hop_length (int):
            The hop length between consecutive frames, in samples.

        win_length (int):
            The window length, in samples.

        n_mel_channels (int):
            The number of mel channels.

        mel_fmin (float):
            The minimum frequency of the mel filterbank.

        mel_fmax (Optional[float]):
            The maximum frequency of the mel filterbank. `None` uses half the sampling rate.
    """

    max_wav_value: float = field(default=32768.0)
    input_sample_rate: int = field(default=16000)
    output_sample_rate: int = field(default=24000)
    filter_length: int = field(default=1280)
    hop_length: int = field(default=320)
    win_length: int = field(default=1280)
    n_mel_channels: int = field(default=80)
    mel_fmin: float = field(default=0.0)
    mel_fmax: Optional[float] = field(default=None)
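
    # Note (added for clarity): with the defaults above, `hop_length=320` at
    # `input_sample_rate=16000` is a 20 ms frame hop, and `filter_length=1280`
    # yields 1280 // 2 + 1 = 641 linear-spectrogram bins, matching the
    # `spec_channels=641` default in `FreeVCArgs` below.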


@dataclass
class FreeVCArgs(Coqpit):
    """FreeVC model arguments

    Args:
        spec_channels (int):
            The number of channels in the spectrogram.

        inter_channels (int):
            The number of channels in the intermediate layers.

        hidden_channels (int):
            The number of channels in the hidden layers.

        filter_channels (int):
            The number of channels in the filter layers.

        n_heads (int):
            The number of attention heads.

        n_layers (int):
            The number of layers.

        kernel_size (int):
            The size of the kernel.

        p_dropout (float):
            The dropout probability.

        resblock (str):
            The type of residual block.

        resblock_kernel_sizes (List[int]):
            The kernel sizes for the residual blocks.

        resblock_dilation_sizes (List[List[int]]):
            The dilation sizes for the residual blocks.

        upsample_rates (List[int]):
            The upsample rates.

        upsample_initial_channel (int):
            The number of channels in the initial upsample layer.

        upsample_kernel_sizes (List[int]):
            The kernel sizes for the upsample layers.

        n_layers_q (int):
            The number of layers in the quantization network.

        use_spectral_norm (bool):
            Whether to use spectral normalization.

        gin_channels (int):
            The number of channels in the global conditioning vector.

        ssl_dim (int):
            The dimension of the self-supervised learning embedding.

        use_spk (bool):
            Whether to use an external speaker encoder.
    """

    spec_channels: int = field(default=641)
    inter_channels: int = field(default=192)
    hidden_channels: int = field(default=192)
    filter_channels: int = field(default=768)
    n_heads: int = field(default=2)
    n_layers: int = field(default=6)
    kernel_size: int = field(default=3)
    p_dropout: float = field(default=0.1)
    resblock: str = field(default="1")
    resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
    upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
    upsample_initial_channel: int = field(default=512)
    upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
    n_layers_q: int = field(default=3)
    use_spectral_norm: bool = field(default=False)
    gin_channels: int = field(default=256)
    ssl_dim: int = field(default=1024)
    use_spk: bool = field(default=False)
    num_spks: int = field(default=0)
    segment_size: int = field(default=8960)


@dataclass
class FreeVCConfig(BaseVCConfig):
    """Defines parameters for the FreeVC end-to-end voice conversion model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (FreeVCArgs):
            Model architecture arguments. Defaults to `FreeVCArgs()`.

        audio (FreeVCAudioConfig):
            Audio processing configuration. Defaults to `FreeVCAudioConfig()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*` classes.
            Defaults to `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch": -1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*` classes.
            Defaults to `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch": -1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch, otherwise after each step. Defaults to `False`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*` classes. Defaults to `AdamW`.

        kl_loss_alpha (float):
            Loss weight for the KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        return_wav (bool):
            If true, the data loader returns the waveform alongside the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        use_weighted_sampler (bool):
            If true, use a weighted sampler with bucketing to balance samples between the datasets used in training.
            Defaults to `False`.

        weighted_sampler_attrs (dict):
            Keys returned by the formatter to be used by the weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}`
            sets sample probabilities by overweighting `root_path` by 2.0. Defaults to `{}`.

        weighted_sampler_multipliers (dict):
            Weight each unique value of a key returned by the formatter for weighted sampling.
            For example `{"root_path": {"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/": 1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`
            samples instances from `train-clean-100` twice as often as from `train-clean-360`. Defaults to `{}`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added between every character. Defaults to `True`.

        test_sentences (List[List]):
            List of sentences with speaker and language information to be used for testing.

        language_ids_file (str):
            Path to the language ids file.

        use_language_embedding (bool):
            If true, a language embedding is used. Defaults to `False`.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.vc.configs.freevc_config import FreeVCConfig
        >>> config = FreeVCConfig()
    """

    model: str = "freevc"
    # model specific params
    model_args: FreeVCArgs = field(default_factory=FreeVCArgs)
    audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig)

    # optimizer
    # TODO with training support

    # loss params
    # TODO with training support

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # sampler params
    use_weighted_sampler: bool = False  # TODO: move it to the base config
    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

    # overrides
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # multi-speaker settings
    # use speaker embedding layer
    num_speakers: int = 0
    speakers_file: str = None
    speaker_embedding_channels: int = 256

    # use d-vectors
    use_d_vector_file: bool = False
    d_vector_file: List[str] = None
    d_vector_dim: int = None

    def __post_init__(self):
        for key, val in self.model_args.items():
            if hasattr(self, key):
                self[key] = val
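
# Usage sketch (illustrative, not part of the original file). The defaults above
# give a ready-to-use inference config:
#
#     >>> config = FreeVCConfig()
#     >>> config.model
#     'freevc'
#     >>> config.audio.input_sample_rate
#     16000
#     >>> config.model_args.use_spk
#     False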
@@ -0,0 +1,155 @@
from dataclasses import asdict, dataclass, field
from typing import Dict, List

from coqpit import Coqpit, check_argument

from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig


@dataclass
class BaseVCConfig(BaseTrainingConfig):
    """Shared parameters among all the VC models.

    Args:

        audio (BaseAudioConfig):
            Audio processor config object instance.

        batch_group_size (int):
            Size of the batch groups used for bucketing. By default, the data loader orders samples by sequence
            length for more efficient and stable training. If `batch_group_size > 1`, it performs bucketing to
            prevent using the same batches in every epoch.

        loss_masking (bool):
            Enable / disable masking of loss values against padded segments of samples in a batch.

        min_text_len (int):
            Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 1.

        max_text_len (int):
            Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").

        min_audio_len (int):
            Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 1.

        max_audio_len (int):
            Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
            dataset defines the VRAM used in training. Hence, pay attention to this value if you encounter an
            OOM error in training. Defaults to float("inf").

        compute_f0 (int):
            (Not in use yet).

        compute_energy (int):
            (Not in use yet).

        compute_linear_spec (bool):
            If True, the data loader computes and returns linear spectrograms alongside the other data.

        precompute_num_workers (int):
            Number of workers used to precompute features. Defaults to 0.

        use_noise_augment (bool):
            Augment the input audio with random noise.

        start_by_longest (bool):
            If True, the data loader starts by loading the longest batch first. It is useful for checking OOM issues.
            Defaults to False.

        shuffle (bool):
            If True, the data loader shuffles the dataset when no sampler is defined. Defaults to False.

        drop_last (bool):
            If True, the data loader drops the last batch if it is not complete. It helps to prevent
            issues that emerge from partial-batch statistics. Defaults to False.

        add_blank (bool):
            Add blank characters between every two characters. It improves performance for some models at the expense
            of slower run-time due to the longer input sequences.

        datasets (List[BaseDatasetConfig]):
            List of datasets used for training. If multiple datasets are provided, they are merged and used together
            for training.

        optimizer (str):
            Optimizer used for training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
            Defaults to `radam`.

        optimizer_params (dict):
            Optimizer kwargs, e.g. `{"betas": [0.8, 0.99], "weight_decay": 0.0}`. Defaults to `None`.

        lr_scheduler (str):
            Learning rate scheduler for training. Use one of the `torch.optim.lr_scheduler` schedulers or
            `TTS.utils.training`. Defaults to `None`.

        lr_scheduler_params (dict):
            Parameters for the generator learning rate scheduler, e.g. `{"warmup": 4000}`. Defaults to `{}`.

        test_sentences (List[str]):
            List of sentences to be used at testing. Defaults to `[]`.

        eval_split_max_size (int):
            Maximum number of samples to be used for evaluation in the proportion split. Defaults to None (disabled).

        eval_split_size (float):
            If between 0.0 and 1.0, represents the proportion of the dataset to include in the evaluation set.
            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).

        use_speaker_weighted_sampler (bool):
            Enable / disable the batch balancer by speaker. Defaults to ```False```.

        speaker_weighted_sampler_alpha (float):
            Controls the influence of the speaker sampler weights. Defaults to ```1.0```.

        use_language_weighted_sampler (bool):
            Enable / disable the batch balancer by language. Defaults to ```False```.

        language_weighted_sampler_alpha (float):
            Controls the influence of the language sampler weights. Defaults to ```1.0```.

        use_length_weighted_sampler (bool):
            Enable / disable the batch balancer by audio length. If enabled, the dataset is divided
            into 10 buckets considering the min and max audio lengths of the dataset. The sampler weights are
            computed to force the same quantity of data from each bucket into each training batch. Defaults to ```False```.

        length_weighted_sampler_alpha (float):
            Controls the influence of the length sampler weights. Defaults to ```1.0```.
    """

    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    # training params
    batch_group_size: int = 0
    loss_masking: bool = None
    # dataloading
    min_audio_len: int = 1
    max_audio_len: int = float("inf")
    min_text_len: int = 1
    max_text_len: int = float("inf")
    compute_f0: bool = False
    compute_energy: bool = False
    compute_linear_spec: bool = False
    precompute_num_workers: int = 0
    use_noise_augment: bool = False
    start_by_longest: bool = False
    shuffle: bool = False
    drop_last: bool = False
    # dataset
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # optimizer
    optimizer: str = "radam"
    optimizer_params: dict = None
    # scheduler
    lr_scheduler: str = None
    lr_scheduler_params: dict = field(default_factory=lambda: {})
    # testing
    test_sentences: List[str] = field(default_factory=lambda: [])
    # evaluation
    eval_split_max_size: int = None
    eval_split_size: float = 0.01
    # weighted samplers
    use_speaker_weighted_sampler: bool = False
    speaker_weighted_sampler_alpha: float = 1.0
    use_language_weighted_sampler: bool = False
    language_weighted_sampler_alpha: float = 1.0
    use_length_weighted_sampler: bool = False
    length_weighted_sampler_alpha: float = 1.0
@@ -0,0 +1,17 @@
import importlib
import re
from typing import Dict, List, Union


def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
    print(" > Using model: {}".format(config.model))
    # fetch the right model implementation.
    if "model" in config and config["model"].lower() == "freevc":
        MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
    model = MyModel.init_from_config(config, samples)
    return model
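
# Usage sketch (illustrative; assumes this module is `TTS.vc.models`):
#
#     >>> from TTS.vc.models import setup_model
#     >>> from TTS.vc.configs.freevc_config import FreeVCConfig
#     >>> model = setup_model(FreeVCConfig())  # prints " > Using model: freevc"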
Binary file not shown.
@@ -0,0 +1,429 @@
import os
import random
from typing import Dict, List, Tuple, Union

import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper

from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

# pylint: skip-file


class BaseVC(BaseTrainerModel):
    """Base `vc` class. Every new `vc` model must inherit this.

    It defines common `vc` specific functions on top of the `Model` implementation.
    """

    MODEL_TYPE = "vc"

    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor",
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
    ):
        super().__init__()
        self.config = config
        self.ap = ap
        self.speaker_manager = speaker_manager
        self.language_manager = language_manager
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Set up model args based on the config type (`ModelConfig` or `ModelArgs`).

        `ModelArgs` has all the fields required to initialize the model architecture.

        `ModelConfig` has all the fields required for training and inference, and contains `ModelArgs`.

        If the config is for training, with a name like "*Config", then the model args are embedded in
        `config.model_args`.

        If the config is for the model, with a name like "*Args", then we assign it directly.
        """
        # don't use isinstance to avoid importing recursively
        if "Config" in config.__class__.__name__:
            self.config = config
            self.args = config.model_args
        elif "Args" in config.__class__.__name__:
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")

    def init_multispeaker(self, config: Coqpit, data: List = None):
        """Initialize a speaker embedding layer if needed and define the expected embedding channel size for setting
        the `in_channels` size of the connected layers.

        This implementation yields 3 possible outcomes:

        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
        2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.
        3. If `config.use_speaker_embedding` is True, initialize a speaker embedding layer with channel size of
        `config.d_vector_dim` or 512.

        You can override this function for new models.

        Args:
            config (Coqpit): Model configuration.
        """
        # set number of speakers
        if self.speaker_manager is not None:
            self.num_speakers = self.speaker_manager.num_speakers
        elif hasattr(config, "num_speakers"):
            self.num_speakers = config.num_speakers

        # set ultimate speaker embedding size
        if config.use_speaker_embedding or config.use_d_vector_file:
            self.embedded_speaker_dim = (
                config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
            )
        # init speaker embedding layer
        if config.use_speaker_embedding and not config.use_d_vector_file:
            print(" > Init speaker_embedding layer.")
            self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)

    def get_aux_input(self, **kwargs) -> Dict:
        """Prepare and return `aux_input` used by `forward()`"""
        return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

    def get_aux_input_from_test_sentences(self, sentence_info):
        if hasattr(self.config, "model_args"):
            config = self.config.model_args
        else:
            config = self.config

        # extract speaker and language info
        text, speaker_name, style_wav, language_name = None, None, None, None

        if isinstance(sentence_info, list):
            if len(sentence_info) == 1:
                text = sentence_info[0]
            elif len(sentence_info) == 2:
                text, speaker_name = sentence_info
            elif len(sentence_info) == 3:
                text, speaker_name, style_wav = sentence_info
            elif len(sentence_info) == 4:
                text, speaker_name, style_wav, language_name = sentence_info
        else:
            text = sentence_info

        # get speaker id/d_vector
        speaker_id, d_vector, language_id = None, None, None
        if self.speaker_manager is not None:
            if config.use_d_vector_file:
                if speaker_name is None:
                    d_vector = self.speaker_manager.get_random_embedding()
                else:
                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
            elif config.use_speaker_embedding:
                if speaker_name is None:
                    speaker_id = self.speaker_manager.get_random_id()
                else:
                    speaker_id = self.speaker_manager.name_to_id[speaker_name]

        # get language id
        if self.language_manager is not None and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.name_to_id[language_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": language_id,
        }

    def format_batch(self, batch: Dict) -> Dict:
        """Generic batch formatting for `VCDataset`.

        You must override this if you use a custom dataset.

        Args:
            batch (Dict): The raw batch produced by the data loader.

        Returns:
            Dict: The formatted batch with model-ready keys.
        """
        # setup input batch
        text_input = batch["token_id"]
        text_lengths = batch["token_id_lengths"]
        speaker_names = batch["speaker_names"]
        linear_input = batch["linear"]
        mel_input = batch["mel"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        item_idx = batch["item_idxs"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        attn_mask = batch["attns"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]
        energy = batch["energy"]
        language_ids = batch["language_ids"]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # compute durations from attention masks
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # compute raw durations
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # smooth the durations and set any 0 duration to 1
                # by cutting off from the largest duration indices.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert (
                    dur.sum() == mel_lengths[idx]
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur

        # set stop targets wrt reduction factor
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": speaker_names,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "stop_target_lengths": stop_target_lengths,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "d_vectors": d_vectors,
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": item_idx,
            "waveform": waveform,
            "pitch": pitch,
            "energy": energy,
            "language_ids": language_ids,
            "audio_unique_names": batch["audio_unique_names"],
        }

    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
        weights = None
        data_items = dataset.samples

        if getattr(config, "use_language_weighted_sampler", False):
            alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
            print(" > Using Language weighted sampler with alpha:", alpha)
            weights = get_language_balancer_weights(data_items) * alpha

        if getattr(config, "use_speaker_weighted_sampler", False):
            alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
            print(" > Using Speaker weighted sampler with alpha:", alpha)
            if weights is not None:
                weights += get_speaker_balancer_weights(data_items) * alpha
            else:
                weights = get_speaker_balancer_weights(data_items) * alpha

        if getattr(config, "use_length_weighted_sampler", False):
            alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
            print(" > Using Length weighted sampler with alpha:", alpha)
            if weights is not None:
                weights += get_length_balancer_weights(data_items) * alpha
            else:
                weights = get_length_balancer_weights(data_items) * alpha

        if weights is not None:
            sampler = WeightedRandomSampler(weights, len(weights))
        else:
            sampler = None

        # sampler for DDP
        if sampler is None:
            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        else:  # If a sampler is already defined, use this sampler and the DDP sampler together
            sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler

        return sampler

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # setup multi-speaker attributes
            if self.speaker_manager is not None:
                if hasattr(config, "model_args"):
                    speaker_id_mapping = (
                        self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
                    )
                    d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                    config.use_d_vector_file = config.model_args.use_d_vector_file
                else:
                    speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
            else:
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup multi-lingual attributes
            if self.language_manager is not None:
                language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
            else:
                language_id_mapping = None

            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                compute_energy=config.get("compute_energy", False),
                energy_cache_path=config.get("energy_cache_path", None),
                samples=samples,
                ap=self.ap,
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                precompute_num_workers=config.precompute_num_workers,
                use_noise_augment=False if is_eval else config.use_noise_augment,
                verbose=verbose,
                speaker_id_mapping=speaker_id_mapping,
                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
                tokenizer=None,
                start_by_longest=config.start_by_longest,
                language_id_mapping=language_id_mapping,
            )

            # wait for all the DDP processes to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=config.shuffle if sampler is None else False,  # if there is no other sampler
                collate_fn=dataset.collate_fn,
                drop_last=config.drop_last,  # setting this False might cause issues in AMP training.
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader

    def _get_test_aux_input(
        self,
    ) -> Dict:
        d_vector = None
        if self.config.use_d_vector_file:
            d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
            d_vector = (random.sample(sorted(d_vector), 1),)

        aux_inputs = {
            "speaker_id": None
            if not self.config.use_speaker_embedding
            else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1),
            "d_vector": d_vector,
            "style_wav": None,  # TODO: handle GST style input
        }
        return aux_inputs

    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
        """Generic test run for `vc` models used by `Trainer`.

        You can override this for a different behaviour.

        Args:
            assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self._get_test_aux_input()
        for idx, sen in enumerate(test_sentences):
            if isinstance(sen, list):
                aux_inputs = self.get_aux_input_from_test_sentences(sen)
                sen = aux_inputs["text"]
            outputs_dict = synthesis(
                self,
                sen,
                self.config,
                "cuda" in str(next(self.parameters()).device),
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                use_griffin_lim=True,
                do_trim_silence=False,
            )
            test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
                outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
            )
            test_figures["{}-alignment".format(idx)] = plot_alignment(
                outputs_dict["outputs"]["alignments"], output_fig=False
            )
        return test_figures, test_audios

    def on_init_start(self, trainer):
        """Save `speakers.pth` and `language_ids.json` at the beginning of training, and update both paths in the config."""
        if self.speaker_manager is not None:
            output_path = os.path.join(trainer.output_path, "speakers.pth")
            self.speaker_manager.save_ids_to_file(output_path)
            trainer.config.speakers_file = output_path
            # some models don't have `model_args` set
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.speakers_file = output_path
            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
            print(f" > `speakers.pth` is saved to {output_path}.")
            print(" > `speakers_file` is updated in the config.json.")

        if self.language_manager is not None:
            output_path = os.path.join(trainer.output_path, "language_ids.json")
            self.language_manager.save_ids_to_file(output_path)
            trainer.config.language_ids_file = output_path
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.language_ids_file = output_path
            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
            print(f" > `language_ids.json` is saved to {output_path}.")
            print(" > `language_ids_file` is updated in the config.json.")
@@ -0,0 +1,562 @@
from typing import Dict, List, Optional, Tuple, Union

import librosa
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.io import load_fsspec
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.base_vc import BaseVC
from TTS.vc.modules.freevc.commons import get_padding, init_weights
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm


class ResidualCouplingBlock(nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class Encoder(nn.Module):
    def __init__(
        self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
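
    # Note (added for clarity): `z = m + eps * exp(logs)` above is the
    # reparameterization trick, i.e. a differentiable sample from the
    # posterior N(m, exp(logs)**2), masked to the valid time steps.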


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_parametrizations(l, "weight")
        for l in self.resblocks:
            remove_parametrizations(l, "weight")


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class SpeakerEncoder(torch.nn.Module):
    def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
        super(SpeakerEncoder, self).__init__()
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

    def forward(self, mels):
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(mels)
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

    def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
        mel_slices = []
        for i in range(0, total_frames - partial_frames, partial_hop):
            mel_range = torch.arange(i, i + partial_frames)
            mel_slices.append(mel_range)

        return mel_slices

    def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
        mel_len = mel.size(1)
        last_mel = mel[:, -partial_frames:]

        if mel_len > partial_frames:
            mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
            mels = list(mel[:, s] for s in mel_slices)
            mels.append(last_mel)
            mels = torch.stack(tuple(mels), 0).squeeze(1)

            with torch.no_grad():
                partial_embeds = self(mels)
            embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
            # embed = embed / torch.linalg.norm(embed, 2)
        else:
            with torch.no_grad():
                embed = self(last_mel)

        return embed
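
    # Note (added for clarity): long utterances are embedded in overlapping
    # 128-frame windows with a 64-frame hop (50% overlap) plus the final
    # window, and the per-window embeddings are mean-pooled into one vector.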


class FreeVC(BaseVC):
    """

    Paper::
        https://arxiv.org/abs/2210.15418#

    Paper Abstract::
        Voice conversion (VC) can be achieved by first extracting source content information and target speaker
        information, and then reconstructing waveform with these information. However, current approaches normally
        either extract dirty content information with speaker information leaked in, or demand a large amount of
        annotated data for training. Besides, the quality of reconstructed waveform can be degraded by the
        mismatch between conversion model and vocoder. In this paper, we adopt the end-to-end framework of VITS for
        high-quality waveform reconstruction, and propose strategies for clean content information extraction without
        text annotation. We disentangle content information by imposing an information bottleneck to WavLM features,
        and propose the spectrogram-resize based data augmentation to improve the purity of extracted content
        information. Experimental results show that the proposed method outperforms the latest VC models trained with
        annotated data and has greater robustness.

    Original Code::
        https://github.com/OlaWod/FreeVC

    Examples:
        >>> from TTS.vc.configs.freevc_config import FreeVCConfig
        >>> from TTS.vc.models.freevc import FreeVC
        >>> config = FreeVCConfig()
        >>> model = FreeVC(config)
    """

    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
        super().__init__(config, None, speaker_manager, None)

        self.init_multispeaker(config)

        self.spec_channels = self.args.spec_channels
        self.inter_channels = self.args.inter_channels
        self.hidden_channels = self.args.hidden_channels
        self.filter_channels = self.args.filter_channels
        self.n_heads = self.args.n_heads
        self.n_layers = self.args.n_layers
        self.kernel_size = self.args.kernel_size
        self.p_dropout = self.args.p_dropout
        self.resblock = self.args.resblock
        self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
        self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
        self.upsample_rates = self.args.upsample_rates
        self.upsample_initial_channel = self.args.upsample_initial_channel
        self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
        self.segment_size = self.args.segment_size
        self.gin_channels = self.args.gin_channels
        self.ssl_dim = self.args.ssl_dim
        self.use_spk = self.args.use_spk

        self.enc_p = Encoder(self.args.ssl_dim, self.inter_channels, self.hidden_channels, 5, 1, 16)
        self.dec = Generator(
            self.inter_channels,
            self.resblock,
            self.resblock_kernel_sizes,
            self.resblock_dilation_sizes,
            self.upsample_rates,
            self.upsample_initial_channel,
            self.upsample_kernel_sizes,
            gin_channels=self.gin_channels,
        )
        self.enc_q = Encoder(
            self.spec_channels, self.inter_channels, self.hidden_channels, 5, 1, 16, gin_channels=self.gin_channels
        )
        self.flow = ResidualCouplingBlock(
            self.inter_channels, self.hidden_channels, 5, 1, 4, gin_channels=self.gin_channels
        )
        if not self.use_spk:
            self.enc_spk = SpeakerEncoder(model_hidden_size=self.gin_channels, model_embedding_size=self.gin_channels)
        else:
            self.load_pretrained_speaker_encoder()

        self.wavlm = get_wavlm()

    @property
    def device(self):
        return next(self.parameters()).device

    def load_pretrained_speaker_encoder(self):
        """Load the pretrained speaker encoder model as mentioned in the paper."""
        print(" > Loading pretrained speaker encoder model ...")
        self.enc_spk_ex = SpeakerEncoderEx(
            "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
        )

    def init_multispeaker(self, config: Coqpit):
        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
        or with external `d_vectors` computed from a speaker encoder model.

        You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.

        Args:
            config (Coqpit): Model configuration.
            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
        """
        self.num_spks = self.args.num_spks
        if self.speaker_manager:
            self.num_spks = self.speaker_manager.num_spks

    def forward(
        self,
        c: torch.Tensor,
        spec: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        mel: Optional[torch.Tensor] = None,
        c_lengths: Optional[torch.Tensor] = None,
        spec_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """
        Forward pass of the model.

        Args:
            c: WavLM features. Shape: (batch_size, ssl_dim, c_seq_len).
            spec: The input spectrogram. Shape: (batch_size, spec_seq_len, spec_dim).
            g: The speaker embedding. Shape: (batch_size, spk_emb_dim).
            mel: The input mel-spectrogram for the speaker encoder. Shape: (batch_size, mel_seq_len, mel_dim).
            c_lengths: The lengths of the WavLM features. Shape: (batch_size,).
            spec_lengths: The lengths of the spectrogram. Shape: (batch_size,).

        Returns:
            o: The waveform decoded from the randomly sliced latent segment.
            ids_slice: The start indices of the random latent slices. Shape: (batch_size,).
            spec_mask: The spectrogram mask. Shape: (batch_size, 1, spec_seq_len).
            (z, z_p, m_p, logs_p, m_q, logs_q): A tuple of latent variables.
        """

        # If c_lengths is None, set it to the length of the last dimension of c
        if c_lengths is None:
            c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)

        # If spec_lengths is None, set it to the length of the last dimension of spec
        if spec_lengths is None:
            spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)

        # If use_spk is False, compute g from mel using enc_spk
        g = None
        if not self.use_spk:
            g = self.enc_spk(mel).unsqueeze(-1)

        # Compute m_p, logs_p, z, m_q, logs_q, and spec_mask using enc_p and enc_q
        _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
        z, m_q, logs_q, spec_mask = self.enc_q(spec.transpose(1, 2), spec_lengths, g=g)

        # Compute z_p using flow
        z_p = self.flow(z, spec_mask, g=g)

        # Randomly slice z and compute o using dec
        z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
        o = self.dec(z_slice, g=g)

        return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.no_grad()
    def inference(self, c, g=None, mel=None, c_lengths=None):
        """
        Inference pass of the model.

        Args:
            c (torch.Tensor): WavLM features. Shape: (batch_size, ssl_dim, c_seq_len).
            g (torch.Tensor): Speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
            mel (torch.Tensor): Mel-spectrogram tensor. Shape: (batch_size, mel_seq_len, mel_dim).
            c_lengths (torch.Tensor): Lengths of the input features. Shape: (batch_size,).

        Returns:
            torch.Tensor: Output waveform.
        """
        if c_lengths is None:
            c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
        if not self.use_spk:
            g = self.enc_spk.embed_utterance(mel)
            g = g.unsqueeze(-1)
        z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
        z = self.flow(z_p, c_mask, g=g, reverse=True)
        o = self.dec(z * c_mask, g=g)
        return o

    def extract_wavlm_features(self, y):
        """Extract WavLM features from an audio tensor.

        Args:
            y (torch.Tensor): Audio tensor. Shape: (batch_size, audio_seq_len).
        """

        with torch.no_grad():
            c = self.wavlm.extract_features(y)[0]
        c = c.transpose(1, 2)
        return c

    def load_audio(self, wav):
        """Read and format the input audio."""
        if isinstance(wav, str):
            wav, _ = librosa.load(wav, sr=self.config.audio.input_sample_rate)
        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav).to(self.device)
        if isinstance(wav, torch.Tensor):
            wav = wav.to(self.device)
        if isinstance(wav, list):
            wav = torch.from_numpy(np.array(wav)).to(self.device)
        return wav.float()

    @torch.inference_mode()
    def voice_conversion(self, src, tgt):
        """
        Voice conversion pass of the model.

        Args:
            src (str or torch.Tensor): Source utterance.
            tgt (str or torch.Tensor): Target utterance.

        Returns:
            torch.Tensor: Output tensor.
        """

        wav_tgt = self.load_audio(tgt).cpu().numpy()
        wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)

        if self.config.model_args.use_spk:
            g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)
            g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device)
        else:
            wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
            mel_tgt = mel_spectrogram_torch(
                wav_tgt,
                self.config.audio.filter_length,
                self.config.audio.n_mel_channels,
                self.config.audio.input_sample_rate,
                self.config.audio.hop_length,
                self.config.audio.win_length,
                self.config.audio.mel_fmin,
                self.config.audio.mel_fmax,
            )
        # src
        wav_src = self.load_audio(src)
        c = self.extract_wavlm_features(wav_src[None, :])

        if self.config.model_args.use_spk:
            audio = self.inference(c, g=g_tgt)
        else:
            audio = self.inference(c, mel=mel_tgt.transpose(1, 2))
        audio = audio[0][0].data.cpu().float().numpy()
        return audio

    def eval_step(self):
        ...

    @staticmethod
    def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True):
        model = FreeVC(config)
        return model

    def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False):
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"], strict=strict)
        if eval:
            self.eval()

    def train_step(self):
        ...
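
# End-to-end usage sketch (illustrative; the checkpoint path is hypothetical):
#
#     >>> from TTS.vc.configs.freevc_config import FreeVCConfig
#     >>> from TTS.vc.models.freevc import FreeVC
#     >>> config = FreeVCConfig()
#     >>> model = FreeVC.init_from_config(config)
#     >>> model.load_checkpoint(config, "freevc_checkpoint.pth", eval=True)
#     >>> wav = model.voice_conversion(src="source.wav", tgt="target.wav")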
@@ -0,0 +1,164 @@
import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def intersperse(lst, item):
|
||||
result = [item] * (len(lst) * 2 + 1)
|
||||
result[1::2] = lst
|
||||
return result
|
||||
|
||||
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||
"""KL(P||Q)"""
|
||||
kl = (logs_q - logs_p) - 0.5
|
||||
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||
return kl
|
||||
|
||||
|
||||
def rand_gumbel(shape):
|
||||
"""Sample from the Gumbel distribution, protect from overflows."""
|
||||
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
||||
return -torch.log(-torch.log(uniform_samples))
|
||||
|
||||
|
||||
def rand_gumbel_like(x):
|
||||
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
||||
return g
|
||||
|
||||
|
||||
def slice_segments(x, ids_str, segment_size=4):
|
||||
ret = torch.zeros_like(x[:, :, :segment_size])
|
||||
for i in range(x.size(0)):
|
||||
idx_str = ids_str[i]
|
||||
idx_end = idx_str + segment_size
|
||||
ret[i] = x[i, :, idx_str:idx_end]
|
||||
return ret
|
||||
|
||||
|
||||
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size + 1
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
|
||||
|
||||
def rand_spec_segments(x, x_lengths=None, segment_size=4):
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
|
||||
|
||||
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||
position = torch.arange(length, dtype=torch.float)
|
||||
num_timescales = channels // 2
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
||||
)
|
||||
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
||||
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
||||
signal = signal.view(1, channels, length)
|
||||
return signal
|
||||
|
||||
|
||||
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return x + signal.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
||||
|
||||
|
||||
def subsequent_mask(length):
|
||||
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
||||
return mask
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
"""
|
||||
duration: [b, 1, t_x]
|
||||
mask: [b, 1, t_y, t_x]
|
||||
"""
|
||||
device = duration.device
|
||||
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_duration = torch.cumsum(duration, -1)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
path = path.unsqueeze(1).transpose(2, 3) * mask
|
||||
return path
|
||||
|
||||
|
||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||
norm_type = float(norm_type)
|
||||
if clip_value is not None:
|
||||
clip_value = float(clip_value)
|
||||
|
||||
total_norm = 0
|
||||
for p in parameters:
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
if clip_value is not None:
|
||||
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
||||
total_norm = total_norm ** (1.0 / norm_type)
|
||||
return total_norm
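
# A minimal sketch added for illustration (not part of the original file):
# exercising sequence_mask() and rand_slice_segments() on dummy tensors.
if __name__ == "__main__":
    lengths = torch.tensor([3, 5])
    print(sequence_mask(lengths))  # (2, 5) boolean mask; positions past each length are False
    feats = torch.randn(2, 192, 100)  # (batch, channels, frames)
    segments, ids_str = rand_slice_segments(feats, torch.tensor([100, 80]), segment_size=32)
    print(segments.shape)  # torch.Size([2, 192, 32]); ids_str holds each slice's start frame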
@@ -0,0 +1,125 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
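
# A short sketch added for illustration (not in the original file): computing a
# FreeVC-style mel spectrogram with the audio defaults used elsewhere in this
# commit (filter_length=1280, hop_length=320, win_length=1280, 80 mels, 16 kHz).
if __name__ == "__main__":
    wav = torch.randn(1, 16000).clamp(-1.0, 1.0)  # one second of dummy audio in [-1, 1]
    mel = mel_spectrogram_torch(wav, 1280, 80, 16000, 320, 1280, 0.0, None)
    print(mel.shape)  # (1, 80, 50): with center=False, (16000 + 960 padding - 1280) / 320 + 1 = 50 frames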
@@ -0,0 +1,387 @@
import torch
from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

import TTS.vc.modules.freevc.commons as commons
from TTS.vc.modules.freevc.commons import get_padding, init_weights

LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


class WN(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
            )
            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            remove_parametrizations(self.cond_layer, "weight")
        for l in self.in_layers:
            remove_parametrizations(l, "weight")
        for l in self.res_skip_layers:
            remove_parametrizations(l, "weight")


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=d, padding=get_padding(kernel_size, d))
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)))
                for _ in range(3)
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_parametrizations(l, "weight")
        for l in self.convs2:
            remove_parametrizations(l, "weight")


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=d, padding=get_padding(kernel_size, d))
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_parametrizations(l, "weight")


class Log(nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x


class Flip(nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
            return x, logdet
        else:
            return x


class ElementwiseAffine(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x


class ResidualCouplingLayer(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels
        )
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
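
# A small sanity sketch added for illustration (not part of the original file):
# the coupling layer is invertible, so a forward pass followed by a reverse pass
# recovers the input wherever the mask is 1 (trivially so at init, where the
# zero-initialized `post` projection makes the layer the identity).
if __name__ == "__main__":
    layer = ResidualCouplingLayer(channels=4, hidden_channels=8, kernel_size=3, dilation_rate=1, n_layers=2)
    x, x_mask = torch.randn(1, 4, 10), torch.ones(1, 1, 10)
    y, logdet = layer(x, x_mask)
    x_rec = layer(y, x_mask, reverse=True)
    print(torch.allclose(x, x_rec, atol=1e-5))  # True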
@@ -0,0 +1,65 @@
import struct
from pathlib import Path
from typing import Optional, Union

# import webrtcvad
import librosa
import numpy as np
from scipy.ndimage import binary_dilation

from TTS.vc.modules.freevc.speaker_encoder.hparams import *

int16_max = (2**15) - 1


def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None):
    """
    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.

    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
    just .wav), or the waveform as a numpy array of floats.
    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
    preprocessing. After preprocessing, the waveform's sampling rate will match the data
    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
    this argument will be ignored.
    """
    # Load the wav from disk if needed
    if isinstance(fpath_or_wav, (str, Path)):
        wav, source_sr = librosa.load(fpath_or_wav, sr=None)
    else:
        wav = fpath_or_wav

    # Resample the wav if needed
    if source_sr is not None and source_sr != sampling_rate:
        wav = librosa.resample(wav, orig_sr=source_sr, target_sr=sampling_rate)

    # Apply the preprocessing: normalize volume and shorten long silences
    wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
    # Note: trim_long_silences() is not defined in this file; it belonged to the
    # webrtcvad-based VAD whose import is commented out above.
    wav = trim_long_silences(wav)

    return wav


def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this is not a log-mel spectrogram.
    """
    frames = librosa.feature.melspectrogram(
        y=wav,
        sr=sampling_rate,
        n_fft=int(sampling_rate * mel_window_length / 1000),
        hop_length=int(sampling_rate * mel_window_step / 1000),
        n_mels=mel_n_channels,
    )
    return frames.astype(np.float32).T


def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
        return wav
    return wav * (10 ** (dBFS_change / 20))
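
# An illustrative check added here (not in the original file): normalize_volume()
# shifts a waveform toward the target dBFS. A full-scale sine sits near -3 dBFS,
# so pushing it down to -30 dBFS scales it by roughly 10 ** (-27 / 20) ≈ 0.045.
if __name__ == "__main__":
    t = np.linspace(0, 1, 16000, endpoint=False)
    sine = np.sin(2 * np.pi * 440 * t).astype(np.float32)
    out = normalize_volume(sine, target_dBFS=-30, decrease_only=True)
    print(10 * np.log10(np.mean(out**2)))  # ≈ -30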
@@ -0,0 +1,31 @@
## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10  # In milliseconds
mel_n_channels = 40


## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160  # 1600 ms


## Voice Activity Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30


## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3
@@ -0,0 +1,175 @@
from pathlib import Path
from time import perf_counter as timer
from typing import List, Union

import numpy as np
import torch
from torch import nn

from TTS.utils.io import load_fsspec
from TTS.vc.modules.freevc.speaker_encoder import audio
from TTS.vc.modules.freevc.speaker_encoder.hparams import *


class SpeakerEncoder(nn.Module):
    def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True):
        """
        :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
        If None, defaults to cuda if it is available on your machine, otherwise the model will
        run on cpu. Outputs are always returned on the cpu, as numpy arrays.
        """
        super().__init__()

        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        # Get the target device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        # Load the pretrained model's weights
        # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
        # if not weights_fpath.exists():
        #     raise Exception("Couldn't find the voice encoder pretrained model at %s." %
        #                     weights_fpath)
        start = timer()
        checkpoint = load_fsspec(weights_fpath, map_location="cpu")

        self.load_state_dict(checkpoint["model_state"], strict=False)
        self.to(device)

        if verbose:
            print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))

    def forward(self, mels: torch.FloatTensor):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
        (batch_size, n_frames, n_channels)
        :return: the embeddings as a float32 tensor of shape (batch_size, embedding_size).
        Embeddings are positive and L2-normed, thus they lie in the range [0, 1].
        """
        # Pass the input through the LSTM layers and retrieve the final hidden state of the last
        # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
        _, (hidden, _) = self.lstm(mels)
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

    @staticmethod
    def compute_partial_slices(n_samples: int, rate, min_coverage):
        """
        Computes where to split an utterance waveform and its corresponding mel spectrogram to
        obtain partial utterances of <partials_n_frames> each. Both the waveform and the
        mel spectrogram slices are returned, so as to make each partial utterance waveform
        correspond to its spectrogram.

        The returned ranges may index past the end of the waveform. It is
        recommended that you pad the waveform with zeros up to wav_slices[-1].stop.

        :param n_samples: the number of samples in the waveform
        :param rate: how many partial utterances should occur per second. Partial utterances must
        cover the span of the entire utterance, thus the rate should not be lower than the inverse
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
        the minimum rate is thus 0.625.
        :param min_coverage: when reaching the last partial utterance, it may or may not have
        enough frames. If at least <min_coverage> of <partials_n_frames> are present,
        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
        it will be discarded. If there aren't enough frames for one partial utterance,
        this parameter is ignored so that the function always returns at least one slice.
        :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
        respectively the waveform and the mel spectrogram with these slices to obtain the partial
        utterances.
        """
        assert 0 < min_coverage <= 1

        # Compute how many frames separate two partial utterances
        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
        assert 0 < frame_step, "The rate is too high"
        assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % (
            sampling_rate / (samples_per_frame * partials_n_frames)
        )

        # Compute the slices
        wav_slices, mel_slices = [], []
        steps = max(1, n_frames - partials_n_frames + frame_step + 1)
        for i in range(0, steps, frame_step):
            mel_range = np.array([i, i + partials_n_frames])
            wav_range = mel_range * samples_per_frame
            mel_slices.append(slice(*mel_range))
            wav_slices.append(slice(*wav_range))

        # Evaluate whether extra padding is warranted or not
        last_wav_range = wav_slices[-1]
        coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
        if coverage < min_coverage and len(mel_slices) > 1:
            mel_slices = mel_slices[:-1]
            wav_slices = wav_slices[:-1]

        return wav_slices, mel_slices

    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
        """
        Computes an embedding for a single utterance. The utterance is divided in partial
        utterances and an embedding is computed for each. The complete utterance embedding is the
        L2-normed average embedding of the partial utterances.

        TODO: independent batched version of this function

        :param wav: a preprocessed utterance waveform as a numpy array of float32
        :param return_partials: if True, the partial embeddings will also be returned along with
        the wav slices corresponding to each partial utterance.
        :param rate: how many partial utterances should occur per second. Partial utterances must
        cover the span of the entire utterance, thus the rate should not be lower than the inverse
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
        the minimum rate is thus 0.625.
        :param min_coverage: when reaching the last partial utterance, it may or may not have
        enough frames. If at least <min_coverage> of <partials_n_frames> are present,
        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
        it will be discarded. If there aren't enough frames for one partial utterance,
        this parameter is ignored so that the function always returns at least one slice.
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
        <return_partials> is True, the partial utterances as a numpy array of float32 of shape
        (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
        returned.
        """
        # Compute where to split the utterance into partials and pad the waveform with zeros if
        # the partial utterances cover a larger range.
        wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
        max_wave_length = wav_slices[-1].stop
        if max_wave_length >= len(wav):
            wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

        # Split the utterance into partials and forward them through the model
        mel = audio.wav_to_mel_spectrogram(wav)
        mels = np.array([mel[s] for s in mel_slices])
        with torch.no_grad():
            mels = torch.from_numpy(mels).to(self.device)
            partial_embeds = self(mels).cpu().numpy()

        # Compute the utterance embedding from the partial embeddings
        raw_embed = np.mean(partial_embeds, axis=0)
        embed = raw_embed / np.linalg.norm(raw_embed, 2)

        if return_partials:
            return embed, partial_embeds, wav_slices
        return embed

    def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
        """
        Compute the embedding of a collection of wavs (presumably from the same speaker) by
        averaging their embedding and L2-normalizing it.

        :param wavs: list of wavs as numpy arrays of float32.
        :param kwargs: extra arguments to embed_utterance()
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
        """
        raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
        return raw_embed / np.linalg.norm(raw_embed, 2)
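
# A quick illustration added here (not part of the original file):
# compute_partial_slices() is a pure static method, so we can inspect how a
# 3-second, 16 kHz utterance would be cut into 160-frame (1.6 s) partials at the
# default rate of 1.3 partials per second.
if __name__ == "__main__":
    wav_slices, mel_slices = SpeakerEncoder.compute_partial_slices(3 * sampling_rate, rate=1.3, min_coverage=0.75)
    print(len(mel_slices))  # 3 partials; the first covers mel frames [0, 160) and samples [0, 25600)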
@@ -0,0 +1,35 @@
import os
import urllib.request

import torch

from TTS.utils.generic_utils import get_user_data_dir
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig

model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"


def get_wavlm(device="cpu"):
    """Download the model and return the model object."""
    output_path = get_user_data_dir("tts")

    output_path = os.path.join(output_path, "wavlm")
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    output_path = os.path.join(output_path, "WavLM-Large.pt")
    if not os.path.exists(output_path):
        print(f" > Downloading WavLM model to {output_path} ...")
        urllib.request.urlretrieve(model_uri, output_path)

    checkpoint = torch.load(output_path, map_location=torch.device(device))
    cfg = WavLMConfig(checkpoint["cfg"])
    wavlm = WavLM(cfg).to(device)
    wavlm.load_state_dict(checkpoint["model"])
    wavlm.eval()
    return wavlm


if __name__ == "__main__":
    wavlm = get_wavlm()
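
    # Illustrative follow-up (not in the original file): the loaded model is used
    # the same way as in FreeVC.extract_wavlm_features() above, e.g.
    #
    #     wav = torch.zeros(1, 16000)              # (batch, samples) at 16 kHz
    #     feats = wavlm.extract_features(wav)[0]   # (batch, frames, 1024) for WavLM-Large
    #
    # FreeVC then transposes this to (batch, 1024, frames) for its content encoder.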
@@ -0,0 +1,99 @@
{
  "_name_or_path": "./wavlm-large/",
  "activation_dropout": 0.0,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": ["WavLMModel"],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [512, 512, 512, 512, 512, 512, 512],
  "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
  "conv_stride": [5, 2, 2, 2, 2, 2, 2],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_channel_length": 10,
  "mask_channel_min_space": 1,
  "mask_channel_other": 0.0,
  "mask_channel_prob": 0.0,
  "mask_channel_selection": "static",
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_min_space": 1,
  "mask_time_other": 0.0,
  "mask_time_prob": 0.075,
  "mask_time_selection": "static",
  "max_bucket_distance": 800,
  "model_type": "wavlm",
  "num_adapter_layers": 3,
  "num_attention_heads": 16,
  "num_buckets": 320,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_ctc_classes": 80,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 24,
  "num_negatives": 100,
  "output_hidden_size": 1024,
  "pad_token_id": 0,
  "proj_codevector_dim": 768,
  "replace_prob": 0.5,
  "tokenizer_class": "Wav2Vec2CTCTokenizer",
  "torch_dtype": "float32",
  "transformers_version": "4.15.0.dev0",
  "use_weighted_layer_sum": false,
  "vocab_size": 32
}
@@ -0,0 +1,768 @@
|
||||
# --------------------------------------------------------
|
||||
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
||||
# Copyright (c) 2021 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class TransposeLast(nn.Module):
|
||||
def __init__(self, deconstruct_idx=None):
|
||||
super().__init__()
|
||||
self.deconstruct_idx = deconstruct_idx
|
||||
|
||||
def forward(self, x):
|
||||
if self.deconstruct_idx is not None:
|
||||
x = x[self.deconstruct_idx]
|
||||
return x.transpose(-2, -1)
|
||||
|
||||
|
||||
class Fp32LayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.layer_norm(
|
||||
input.float(),
|
||||
self.normalized_shape,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
class Fp32GroupNorm(nn.GroupNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.group_norm(
|
||||
input.float(),
|
||||
self.num_groups,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.new(x)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad * ctx.scale, None
|
||||
|
||||
|
||||
class SamePad(nn.Module):
|
||||
def __init__(self, kernel_size, causal=False):
|
||||
super().__init__()
|
||||
if causal:
|
||||
self.remove = kernel_size - 1
|
||||
else:
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove]
|
||||
return x
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
"""Swish function"""
|
||||
|
||||
def __init__(self):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(Swish, self).__init__()
|
||||
self.act = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.act(x)
|
||||
|
||||
|
||||
class GLU_Linear(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||
super(GLU_Linear, self).__init__()
|
||||
|
||||
self.glu_type = glu_type
|
||||
self.output_dim = output_dim
|
||||
|
||||
if glu_type == "sigmoid":
|
||||
self.glu_act = torch.nn.Sigmoid()
|
||||
elif glu_type == "swish":
|
||||
self.glu_act = Swish()
|
||||
elif glu_type == "relu":
|
||||
self.glu_act = torch.nn.ReLU()
|
||||
elif glu_type == "gelu":
|
||||
self.glu_act = torch.nn.GELU()
|
||||
|
||||
if bias_in_glu:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||
else:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||
|
||||
def forward(self, x):
|
||||
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
||||
x = self.linear(x)
|
||||
|
||||
if self.glu_type == "bilinear":
|
||||
x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
|
||||
else:
|
||||
x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gelu_accurate(x):
|
||||
if not hasattr(gelu_accurate, "_a"):
|
||||
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||
return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||
|
||||
|
||||
def get_activation_fn(activation: str):
|
||||
"""Returns the activation function corresponding to `activation`"""
|
||||
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "gelu":
|
||||
return gelu
|
||||
elif activation == "gelu_fast":
|
||||
warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
|
||||
return gelu_accurate
|
||||
elif activation == "gelu_accurate":
|
||||
return gelu_accurate
|
||||
elif activation == "tanh":
|
||||
return torch.tanh
|
||||
elif activation == "linear":
|
||||
return lambda x: x
|
||||
elif activation == "glu":
|
||||
return lambda x: x
|
||||
else:
|
||||
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||
|
||||
|
||||
def init_bert_params(module):
|
||||
"""
|
||||
Initialize the weights specific to the BERT Model.
|
||||
This overrides the default initializations depending on the specified arguments.
|
||||
1. If normal_init_linear_weights is set then weights of linear
|
||||
layer will be initialized using the normal distribution and
|
||||
bais will be set to the specified value.
|
||||
2. If normal_init_embed_weights is set then weights of embedding
|
||||
layer will be initialized using the normal distribution.
|
||||
3. If normal_init_proj_weights is set then weights of
|
||||
in_project_weight for MultiHeadAttention initialized using
|
||||
the normal distribution (to be validated).
|
||||
"""
|
||||
|
||||
def normal_(data):
|
||||
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||
# so that the RNG is consistent with and without FSDP
|
||||
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
normal_(module.weight.data)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
normal_(module.weight.data)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if isinstance(module, MultiheadAttention):
|
||||
normal_(module.q_proj.weight.data)
|
||||
normal_(module.k_proj.weight.data)
|
||||
normal_(module.v_proj.weight.data)
|
||||
|
||||
|
||||
def quant_noise(module, p, block_size):
|
||||
"""
|
||||
Wraps modules and applies quantization noise to the weights for
|
||||
subsequent quantization with Iterative Product Quantization as
|
||||
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||
|
||||
Args:
|
||||
- module: nn.Module
|
||||
- p: amount of Quantization Noise
|
||||
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||
|
||||
Remarks:
|
||||
- Module weights must have the right sizes wrt the block size
|
||||
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||
- For more detail on how to quantize by blocks with convolutional weights,
|
||||
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||
- We implement the simplest form of noise here as stated in the paper
|
||||
which consists in randomly dropping blocks
|
||||
"""
|
||||
|
||||
# if no quantization noise, don't register hook
|
||||
if p <= 0:
|
||||
return module
|
||||
|
||||
# supported modules
|
||||
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||
|
||||
# test whether module.weight has the right sizes wrt block_size
|
||||
is_conv = module.weight.ndim == 4
|
||||
|
||||
# 2D matrix
|
||||
if not is_conv:
|
||||
assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
|
||||
|
||||
# 4D matrix
|
||||
else:
|
||||
# 1x1 convolutions
|
||||
if module.kernel_size == (1, 1):
|
||||
assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
|
||||
# regular convolutions
|
||||
else:
|
||||
k = module.kernel_size[0] * module.kernel_size[1]
|
||||
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||
|
||||
def _forward_pre_hook(mod, input):
|
||||
# no noise for evaluation
|
||||
if mod.training:
|
||||
if not is_conv:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_features = weight.size(1)
|
||||
out_features = weight.size(0)
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||
|
||||
else:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_channels = mod.in_channels
|
||||
out_channels = mod.out_channels
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
if mod.kernel_size == (1, 1):
|
||||
mask = torch.zeros(
|
||||
int(in_channels // block_size * out_channels),
|
||||
device=weight.device,
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||
else:
|
||||
mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||
|
||||
# scale weights and apply mask
|
||||
mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
|
||||
s = 1 / (1 - p)
|
||||
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||
|
||||
module.register_forward_pre_hook(_forward_pre_hook)
|
||||
return module
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""Multi-headed attention.
|
||||
|
||||
See "Attention Is All You Need" for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
self_attention=False,
|
||||
encoder_decoder_attention=False,
|
||||
q_noise=0.0,
|
||||
qn_block_size=8,
|
||||
has_relative_attention_bias=False,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
gru_rel_pos=False,
|
||||
rescale_init=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout_module = nn.Dropout(dropout)
|
||||
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
self.num_buckets = num_buckets
|
||||
self.max_distance = max_distance
|
||||
if self.has_relative_attention_bias:
|
||||
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.q_head_dim = self.head_dim
|
||||
self.k_head_dim = self.head_dim
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.self_attention = self_attention
|
||||
self.encoder_decoder_attention = encoder_decoder_attention
|
||||
|
||||
assert not self.self_attention or self.qkv_same_dim, (
|
||||
"Self-attention requires query, key and " "value to be of the same size"
|
||||
)
|
||||
|
||||
k_bias = True
|
||||
if rescale_init:
|
||||
k_bias = False
|
||||
|
||||
k_embed_dim = embed_dim
|
||||
q_embed_dim = embed_dim
|
||||
|
||||
self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
|
||||
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
||||
self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
|
||||
|
||||
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self.gru_rel_pos = gru_rel_pos
|
||||
if self.gru_rel_pos:
|
||||
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
||||
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.qkv_same_dim:
|
||||
# Empirically observed the convergence to be much better with
|
||||
# the scaled initialization
|
||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||
else:
|
||||
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||
if self.out_proj.bias is not None:
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
nn.init.xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
nn.init.xavier_normal_(self.bias_v)
|
||||
if self.has_relative_attention_bias:
|
||||
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
||||
|
||||
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
||||
num_buckets = self.num_buckets
|
||||
max_distance = self.max_distance
|
||||
relative_buckets = 0
|
||||
|
||||
if bidirectional:
|
||||
num_buckets = num_buckets // 2
|
||||
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
||||
relative_positions = torch.abs(relative_positions)
|
||||
else:
|
||||
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_positions < max_exact
|
||||
|
||||
relative_postion_if_large = max_exact + (
|
||||
torch.log(relative_positions.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_postion_if_large = torch.min(
|
||||
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
||||
)
|
||||
|
||||
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length):
|
||||
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
||||
relative_position = memory_position - context_position
|
||||
relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
|
||||
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
||||
values = self.relative_attention_bias(relative_position_bucket)
|
||||
values = values.permute([2, 0, 1])
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query,
|
||||
key: Optional[Tensor],
|
||||
value: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
need_weights: bool = True,
|
||||
static_kv: bool = False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
before_softmax: bool = False,
|
||||
need_head_weights: bool = False,
|
||||
position_bias: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
"""Input shape: Time x Batch x Channel
|
||||
|
||||
Args:
|
||||
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||
keys that are pads, of shape `(batch, src_len)`, where
|
||||
padding elements are indicated by 1s.
|
||||
need_weights (bool, optional): return the attention weights,
|
||||
averaged over heads (default: False).
|
||||
attn_mask (ByteTensor, optional): typically used to
|
||||
implement causal attention, where the mask prevents the
|
||||
attention from looking forward in time (default: None).
|
||||
before_softmax (bool, optional): return the raw attention
|
||||
weights and values before the attention softmax.
|
||||
need_head_weights (bool, optional): return the attention
|
||||
weights for each head. Implies *need_weights*. Default:
|
||||
return the average attention weights over all heads.
|
||||
"""
|
||||
if need_head_weights:
|
||||
need_weights = True
|
||||
|
||||
is_tpu = query.device.type == "xla"
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
src_len = tgt_len
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
if key is not None:
|
||||
src_len, key_bsz, _ = key.size()
|
||||
if not torch.jit.is_scripting():
|
||||
assert key_bsz == bsz
|
||||
assert value is not None
|
||||
assert src_len, bsz == value.shape[:2]
|
||||
|
||||
if self.has_relative_attention_bias and position_bias is None:
|
||||
position_bias = self.compute_bias(tgt_len, src_len)
|
||||
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if (
|
||||
not is_tpu # don't use PyTorch version on TPUs
|
||||
and incremental_state is None
|
||||
and not static_kv
|
||||
# A workaround for quantization to work. Otherwise JIT compilation
|
||||
# treats bias in linear module as method.
|
||||
and not torch.jit.is_scripting()
|
||||
and self.q_head_dim == self.head_dim
|
||||
):
|
||||
assert key is not None and value is not None
|
||||
assert attn_mask is None
|
||||
|
||||
attn_mask_rel_pos = None
|
||||
if position_bias is not None:
|
||||
attn_mask_rel_pos = position_bias
|
||||
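# Descriptive note (added; not in the original source): when gru_rel_pos is set,
# a gate computed from the query via grep_linear rescales position_bias per head
# before it is merged into the attention mask passed to multi_head_attention_forward.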
if self.gru_rel_pos:
|
||||
query_layer = query.transpose(0, 1)
|
||||
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
||||
query_layer = query_layer.view(*new_x_shape)
|
||||
query_layer = query_layer.permute(0, 2, 1, 3)
|
||||
_B, _H, _L, __ = query_layer.size()
|
||||
|
||||
gate_a, gate_b = torch.sigmoid(
|
||||
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
|
||||
).chunk(2, dim=-1)
|
||||
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
||||
|
||||
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
||||
k_proj_bias = self.k_proj.bias
|
||||
if k_proj_bias is None:
|
||||
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
||||
|
||||
x, attn = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
torch.empty([0]),
|
||||
torch.cat((self.q_proj.bias, k_proj_bias, self.v_proj.bias)),  # fall back to the zero bias computed above when k_proj has no bias
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout_module.p,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
self.training,
|
||||
# self.training or self.dropout_module.apply_during_inference,
|
||||
key_padding_mask,
|
||||
need_weights,
|
||||
attn_mask_rel_pos,
|
||||
use_separate_proj_weight=True,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
)
|
||||
return x, attn, position_bias
|
||||
|
||||
if incremental_state is not None:
|
||||
saved_state = self._get_input_buffer(incremental_state)
|
||||
if saved_state is not None and "prev_key" in saved_state:
|
||||
# previous time steps are cached - no need to recompute
|
||||
# key and value if they are static
|
||||
if static_kv:
|
||||
assert self.encoder_decoder_attention and not self.self_attention
|
||||
key = value = None
|
||||
else:
|
||||
saved_state = None
|
||||
|
||||
if self.self_attention:
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(query)
|
||||
v = self.v_proj(query)
|
||||
elif self.encoder_decoder_attention:
|
||||
# encoder-decoder attention
|
||||
q = self.q_proj(query)
|
||||
if key is None:
|
||||
assert value is None
|
||||
k = v = None
|
||||
else:
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(key)
|
||||
|
||||
else:
|
||||
assert key is not None and value is not None
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
q *= self.scaling
|
||||
|
||||
if self.bias_k is not None:
|
||||
assert self.bias_v is not None
|
||||
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
||||
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[
|
||||
key_padding_mask,
|
||||
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
|
||||
if k is not None:
|
||||
k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
|
||||
if v is not None:
|
||||
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
||||
|
||||
if saved_state is not None:
|
||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||
if "prev_key" in saved_state:
|
||||
_prev_key = saved_state["prev_key"]
|
||||
assert _prev_key is not None
|
||||
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
k = prev_key
|
||||
else:
|
||||
assert k is not None
|
||||
k = torch.cat([prev_key, k], dim=1)
|
||||
src_len = k.size(1)
|
||||
if "prev_value" in saved_state:
|
||||
_prev_value = saved_state["prev_value"]
|
||||
assert _prev_value is not None
|
||||
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
v = prev_value
|
||||
else:
|
||||
assert v is not None
|
||||
v = torch.cat([prev_value, v], dim=1)
|
||||
prev_key_padding_mask: Optional[Tensor] = None
|
||||
if "prev_key_padding_mask" in saved_state:
|
||||
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||
assert k is not None and v is not None
|
||||
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||
key_padding_mask=key_padding_mask,
|
||||
prev_key_padding_mask=prev_key_padding_mask,
|
||||
batch_size=bsz,
|
||||
src_len=k.size(1),
|
||||
static_kv=static_kv,
|
||||
)
|
||||
|
||||
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||
# In this branch incremental_state is never None
|
||||
assert incremental_state is not None
|
||||
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||
assert k is not None
|
||||
assert k.size(1) == src_len
|
||||
|
||||
# This is part of a workaround to get around fork/join parallelism
|
||||
# not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if self.add_zero_attn:
|
||||
assert v is not None
|
||||
src_len += 1
|
||||
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
||||
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[
|
||||
key_padding_mask,
|
||||
torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||
|
||||
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
attn_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
if not is_tpu:
|
||||
attn_weights = attn_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||
float("-inf"),
|
||||
)
|
||||
else:
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if before_softmax:
|
||||
return attn_weights, v, position_bias
|
||||
|
||||
if position_bias is not None:
|
||||
if self.gru_rel_pos == 1:
|
||||
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
||||
_B, _H, _L, __ = query_layer.size()
|
||||
gate_a, gate_b = torch.sigmoid(
|
||||
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
|
||||
).chunk(2, dim=-1)
|
||||
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
||||
|
||||
position_bias = position_bias.view(attn_weights.size())
|
||||
|
||||
attn_weights = attn_weights + position_bias
|
||||
|
||||
attn_weights_float = F.softmax(attn_weights, dim=-1)
|
||||
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||
attn_probs = self.dropout_module(attn_weights)
|
||||
|
||||
assert v is not None
|
||||
attn = torch.bmm(attn_probs, v)
|
||||
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
attn_weights: Optional[Tensor] = None
|
||||
if need_weights:
|
||||
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
||||
if not need_head_weights:
|
||||
# average attention weights over heads
|
||||
attn_weights = attn_weights.mean(dim=0)
|
||||
|
||||
return attn, attn_weights, position_bias
|
||||
|
||||
@staticmethod
|
||||
def _append_prev_key_padding_mask(
|
||||
key_padding_mask: Optional[Tensor],
|
||||
prev_key_padding_mask: Optional[Tensor],
|
||||
batch_size: int,
|
||||
src_len: int,
|
||||
static_kv: bool,
|
||||
) -> Optional[Tensor]:
|
||||
# saved key padding masks have shape (bsz, seq_len)
|
||||
if prev_key_padding_mask is not None and static_kv:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
|
||||
# During incremental decoding, as the padding token enters and
|
||||
# leaves the frame, there will be a time when prev or current
|
||||
# is None
|
||||
elif prev_key_padding_mask is not None:
|
||||
if src_len > prev_key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||
device=prev_key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask.float()
|
||||
elif key_padding_mask is not None:
|
||||
if src_len > key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - key_padding_mask.size(1)),
|
||||
device=key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
|
||||
else:
|
||||
new_key_padding_mask = key_padding_mask.float()
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
return new_key_padding_mask
|
||||
|
||||
def _get_input_buffer(
|
||||
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
||||
) -> Dict[str, Optional[Tensor]]:
|
||||
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||
return empty_result
|
||||
|
||||
def _set_input_buffer(
|
||||
self,
|
||||
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||
buffer: Dict[str, Optional[Tensor]],
|
||||
):
|
||||
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||
|
||||
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||
return attn_weights
|
||||
@@ -0,0 +1,719 @@
|
||||
# --------------------------------------------------------
|
||||
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
||||
# Copyright (c) 2021 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from TTS.vc.modules.freevc.wavlm.modules import (
|
||||
Fp32GroupNorm,
|
||||
Fp32LayerNorm,
|
||||
GLU_Linear,
|
||||
GradMultiply,
|
||||
MultiheadAttention,
|
||||
SamePad,
|
||||
TransposeLast,
|
||||
get_activation_fn,
|
||||
init_bert_params,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
mask_type: str = "static",
|
||||
mask_other: float = 0.0,
|
||||
min_masks: int = 0,
|
||||
no_overlap: bool = False,
|
||||
min_space: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the shape for which to compute masks.
|
||||
should be of size 2, where the first element is the batch size and the second is the number of timesteps
|
||||
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
However, due to overlaps the actual number of masked elements will be smaller (unless no_overlap is True).
|
||||
mask_type: how to compute mask lengths
|
||||
static = fixed size
|
||||
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
||||
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
||||
poisson = sample from Poisson distribution with lambda = mask_length
|
||||
min_masks: minimum number of masked spans
|
||||
no_overlap: if True, switches to an alternative recursive algorithm that prevents spans from overlapping
|
||||
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
||||
"""
|
||||
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
for i in range(bsz):
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
sz = all_sz
|
||||
num_mask = all_num_mask
|
||||
|
||||
if mask_type == "static":
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
elif mask_type == "uniform":
|
||||
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
||||
elif mask_type == "normal":
|
||||
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
||||
lengths = [max(1, int(round(x))) for x in lengths]
|
||||
elif mask_type == "poisson":
|
||||
lengths = np.random.poisson(mask_length, size=num_mask)
|
||||
lengths = [int(round(x)) for x in lengths]
|
||||
else:
|
||||
raise ValueError("unknown mask selection: " + mask_type)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
lengths[0] = min(mask_length, sz - 1)
|
||||
|
||||
if no_overlap:
|
||||
mask_idc = []
|
||||
|
||||
def arrange(s, e, length, keep_length):
|
||||
span_start = np.random.randint(s, e - length)
|
||||
mask_idc.extend(span_start + i for i in range(length))
|
||||
|
||||
new_parts = []
|
||||
if span_start - s - min_space >= keep_length:
|
||||
new_parts.append((s, span_start - min_space + 1))
|
||||
if e - span_start - keep_length - min_space > keep_length:
|
||||
new_parts.append((span_start + length + min_space, e))
|
||||
return new_parts
|
||||
|
||||
parts = [(0, sz)]
|
||||
min_length = min(lengths)
|
||||
for length in sorted(lengths, reverse=True):
|
||||
lens = np.fromiter(
|
||||
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
||||
int,  # np.int was removed in NumPy 1.24; the builtin int is the drop-in replacement
|
||||
)
|
||||
l_sum = np.sum(lens)
|
||||
if l_sum == 0:
|
||||
break
|
||||
probs = lens / np.sum(lens)
|
||||
c = np.random.choice(len(parts), p=probs)
|
||||
s, e = parts.pop(c)
|
||||
parts.extend(arrange(s, e, length, min_length))
|
||||
mask_idc = np.asarray(mask_idc)
|
||||
else:
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
|
||||
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||
|
||||
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
||||
|
||||
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||
|
||||
min_len = min([len(m) for m in mask_idcs])
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if len(mask_idc) > min_len:
|
||||
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
return mask
|
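# Illustrative usage (a sketch, not part of the original file): mask a batch of two
# 100-step sequences with 10-step spans at 65% coverage probability.
#
#   mask = compute_mask_indices((2, 100), padding_mask=None, mask_prob=0.65, mask_length=10)
#   assert mask.shape == (2, 100) and mask.dtype == np.bool_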
||||
|
||||
|
||||
class WavLMConfig:
|
||||
def __init__(self, cfg=None):
|
||||
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
||||
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||
|
||||
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||
self.activation_fn: str = "gelu" # activation function to use
|
||||
|
||||
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
||||
self.conv_bias: bool = False # include bias in conv encoder
|
||||
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
||||
|
||||
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
||||
|
||||
# dropouts
|
||||
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||
self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
|
||||
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
||||
|
||||
# masking
|
||||
self.mask_length: int = 10 # mask length
|
||||
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
||||
self.mask_selection: str = "static" # how to choose mask length
|
||||
self.mask_other: float = (
|
||||
0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
||||
)
|
||||
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
||||
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
||||
|
||||
# channel masking
|
||||
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
||||
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
||||
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
||||
self.mask_channel_other: float = (
|
||||
0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
||||
)
|
||||
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
||||
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
||||
|
||||
# positional embeddings
|
||||
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||
|
||||
# relative position embedding
|
||||
self.relative_position_embedding: bool = False # apply relative position embedding
|
||||
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
||||
|
||||
if cfg is not None:
|
||||
self.update(cfg)
|
||||
|
||||
def update(self, cfg: dict):
|
||||
self.__dict__.update(cfg)
|
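# Illustrative usage (a sketch, not part of the original file): defaults can be
# overridden from a plain dict, e.g. the config stored inside a checkpoint.
#
#   cfg = WavLMConfig({"encoder_layers": 24, "encoder_embed_dim": 1024})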
||||
|
||||
|
||||
class WavLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg: WavLMConfig,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logger.info(f"WavLM Config: {cfg.__dict__}")
|
||||
|
||||
self.cfg = cfg
|
||||
feature_enc_layers = eval(cfg.conv_feature_layers)  # parses the "[(dim, kernel, stride), ...]" spec string; ast.literal_eval cannot evaluate its list arithmetic
|
||||
self.embed = feature_enc_layers[-1][0]
|
||||
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=cfg.extractor_mode,
|
||||
conv_bias=cfg.conv_bias,
|
||||
)
|
||||
|
||||
self.post_extract_proj = (
|
||||
nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
|
||||
)
|
||||
|
||||
self.mask_prob = cfg.mask_prob
|
||||
self.mask_selection = cfg.mask_selection
|
||||
self.mask_other = cfg.mask_other
|
||||
self.mask_length = cfg.mask_length
|
||||
self.no_mask_overlap = cfg.no_mask_overlap
|
||||
self.mask_min_space = cfg.mask_min_space
|
||||
|
||||
self.mask_channel_prob = cfg.mask_channel_prob
|
||||
self.mask_channel_selection = cfg.mask_channel_selection
|
||||
self.mask_channel_other = cfg.mask_channel_other
|
||||
self.mask_channel_length = cfg.mask_channel_length
|
||||
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
||||
self.mask_channel_min_space = cfg.mask_channel_min_space
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
||||
|
||||
self.feature_grad_mult = cfg.feature_grad_mult
|
||||
|
||||
self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
|
||||
|
||||
self.encoder = TransformerEncoder(cfg)
|
||||
self.layer_norm = LayerNorm(self.embed)
|
||||
|
||||
def apply_mask(self, x, padding_mask):
|
||||
B, T, C = x.shape
|
||||
if self.mask_prob > 0:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_prob,
|
||||
self.mask_length,
|
||||
self.mask_selection,
|
||||
self.mask_other,
|
||||
min_masks=2,
|
||||
no_overlap=self.no_mask_overlap,
|
||||
min_space=self.mask_min_space,
|
||||
)
|
||||
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
||||
x[mask_indices] = self.mask_emb
|
||||
else:
|
||||
mask_indices = None
|
||||
|
||||
if self.mask_channel_prob > 0:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space,
|
||||
)
|
||||
mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
return x, mask_indices
|
||||
|
||||
def forward_padding_mask(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
extra = padding_mask.size(1) % features.size(1)
|
||||
if extra > 0:
|
||||
padding_mask = padding_mask[:, :-extra]
|
||||
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
|
||||
# padding_mask = padding_mask.all(-1)
|
||||
padding_mask = padding_mask.any(-1)
|
||||
return padding_mask
|
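# Descriptive note (added; not in the original source): forward_padding_mask
# downsamples the sample-level padding mask to the feature frame rate; a frame
# counts as padded if any sample it covers is padded (hence `any(-1)` rather than
# the commented-out `all(-1)`).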
||||
|
||||
def extract_features(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = False,
|
||||
ret_conv: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
ret_layer_results: bool = False,
|
||||
):
|
||||
if self.feature_grad_mult > 0:
|
||||
features = self.feature_extractor(source)
|
||||
if self.feature_grad_mult != 1.0:
|
||||
features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
features = self.feature_extractor(source)
|
||||
|
||||
features = features.transpose(1, 2)
|
||||
features = self.layer_norm(features)
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
features = self.dropout_input(features)
|
||||
|
||||
if mask:
|
||||
x, mask_indices = self.apply_mask(features, padding_mask)
|
||||
else:
|
||||
x = features
|
||||
|
||||
# feature: (B, T, D), float
|
||||
# target: (B, T), long
|
||||
# x: (B, T, D), float
|
||||
# padding_mask: (B, T), bool
|
||||
# mask_indices: (B, T), bool
|
||||
x, layer_results = self.encoder(
|
||||
x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1
|
||||
)
|
||||
|
||||
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
||||
|
||||
feature = res["features"] if ret_conv else res["x"]
|
||||
if ret_layer_results:
|
||||
feature = (feature, res["layer_results"])
|
||||
return feature, res["padding_mask"]
|
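# Illustrative usage (a sketch, not part of the original file; the checkpoint path
# and its "cfg"/"model" key layout are assumptions based on the published WavLM
# checkpoints):
#
#   ckpt = torch.load("WavLM-Large.pt", map_location="cpu")
#   model = WavLM(WavLMConfig(ckpt["cfg"]))
#   model.load_state_dict(ckpt["model"])
#   wav = torch.randn(1, 16000)  # one second of 16 kHz audio
#   features, _ = model.extract_features(wav)  # (1, T', encoder_embed_dim)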
||||
|
||||
|
||||
class ConvFeatureExtractionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conv_layers: List[Tuple[int, int, int]],
|
||||
dropout: float = 0.0,
|
||||
mode: str = "default",
|
||||
conv_bias: bool = False,
|
||||
conv_type: str = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode in {"default", "layer_norm"}
|
||||
|
||||
def block(
|
||||
n_in,
|
||||
n_out,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=False,
|
||||
is_group_norm=False,
|
||||
conv_bias=False,
|
||||
):
|
||||
def make_conv():
|
||||
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
||||
nn.init.kaiming_normal_(conv.weight)
|
||||
return conv
|
||||
|
||||
assert not (is_layer_norm and is_group_norm), "layer norm and group norm are exclusive"
|
||||
|
||||
if is_layer_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Sequential(
|
||||
TransposeLast(),
|
||||
Fp32LayerNorm(n_out, elementwise_affine=True),  # n_out is this block's output dim; avoids relying on the loop variable captured by closure
|
||||
TransposeLast(),
|
||||
),
|
||||
nn.GELU(),
|
||||
)
|
||||
elif is_group_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
Fp32GroupNorm(n_out, n_out, affine=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
||||
|
||||
self.conv_type = conv_type
|
||||
if self.conv_type == "default":
|
||||
in_d = 1
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
||||
(dim, k, stride) = cl
|
||||
|
||||
self.conv_layers.append(
|
||||
block(
|
||||
in_d,
|
||||
dim,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=mode == "layer_norm",
|
||||
is_group_norm=mode == "default" and i == 0,
|
||||
conv_bias=conv_bias,
|
||||
)
|
||||
)
|
||||
in_d = dim
|
||||
elif self.conv_type == "conv2d":
|
||||
in_d = 1
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3
|
||||
(dim, k, stride) = cl
|
||||
|
||||
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
|
||||
self.conv_layers.append(torch.nn.ReLU())
|
||||
in_d = dim
|
||||
elif self.conv_type == "custom":
|
||||
in_d = 1
|
||||
idim = 80
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3
|
||||
(dim, k, stride) = cl
|
||||
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1))
|
||||
self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
|
||||
self.conv_layers.append(torch.nn.ReLU())
|
||||
in_d = dim
|
||||
if (i + 1) % 2 == 0:
|
||||
self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True))
|
||||
idim = int(math.ceil(idim / 2))
|
||||
else:
|
||||
pass
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
# BxT -> BxCxT
|
||||
x = x.unsqueeze(1)
|
||||
if self.conv_type == "custom":
|
||||
for conv in self.conv_layers:
|
||||
if isinstance(conv, nn.LayerNorm):
|
||||
x = x.transpose(1, 2)
|
||||
x = conv(x).transpose(1, 2)
|
||||
else:
|
||||
x = conv(x)
|
||||
x = x.transpose(2, 3).contiguous()
|
||||
x = x.view(x.size(0), -1, x.size(-1))
|
||||
else:
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
if self.conv_type == "conv2d":
|
||||
b, c, t, f = x.size()
|
||||
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
|
||||
self.dropout = args.dropout
|
||||
self.embedding_dim = args.encoder_embed_dim
|
||||
|
||||
self.pos_conv = nn.Conv1d(
|
||||
self.embedding_dim,
|
||||
self.embedding_dim,
|
||||
kernel_size=args.conv_pos,
|
||||
padding=args.conv_pos // 2,
|
||||
groups=args.conv_pos_groups,
|
||||
)
|
||||
dropout = 0
|
||||
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
||||
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||
nn.init.constant_(self.pos_conv.bias, 0)
|
||||
|
||||
self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||
|
||||
if hasattr(args, "relative_position_embedding"):
|
||||
self.relative_position_embedding = args.relative_position_embedding
|
||||
self.num_buckets = args.num_buckets
|
||||
self.max_distance = args.max_distance
|
||||
else:
|
||||
self.relative_position_embedding = False
|
||||
self.num_buckets = 0
|
||||
self.max_distance = 0
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TransformerSentenceEncoderLayer(
|
||||
embedding_dim=self.embedding_dim,
|
||||
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||
num_attention_heads=args.encoder_attention_heads,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.activation_dropout,
|
||||
activation_fn=args.activation_fn,
|
||||
layer_norm_first=args.layer_norm_first,
|
||||
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
gru_rel_pos=args.gru_rel_pos,
|
||||
)
|
||||
for i in range(args.encoder_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.layer_norm_first = args.layer_norm_first
|
||||
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||
self.layerdrop = args.encoder_layerdrop
|
||||
|
||||
self.apply(init_bert_params)
|
||||
|
||||
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
||||
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
||||
|
||||
if self.layer_norm_first and layer is None:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
|
||||
if padding_mask is not None:
|
||||
x[padding_mask] = 0
|
||||
|
||||
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||
x_conv = x_conv.transpose(1, 2)
|
||||
x += x_conv
|
||||
|
||||
if not self.layer_norm_first:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
layer_results = []
|
||||
z = None
|
||||
if tgt_layer is not None:
|
||||
layer_results.append((x, z))
|
||||
r = None
|
||||
pos_bias = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
dropout_probability = np.random.random()
|
||||
if not self.training or (dropout_probability > self.layerdrop):
|
||||
x, z, pos_bias = layer(
|
||||
x,
|
||||
self_attn_padding_mask=padding_mask,
|
||||
need_weights=False,
|
||||
self_attn_mask=streaming_mask,
|
||||
pos_bias=pos_bias,
|
||||
)
|
||||
if tgt_layer is not None:
|
||||
layer_results.append((x, z))
|
||||
if i == tgt_layer:
|
||||
r = x
|
||||
break
|
||||
|
||||
if r is not None:
|
||||
x = r
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
|
||||
class TransformerSentenceEncoderLayer(nn.Module):
|
||||
"""
|
||||
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
||||
models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
ffn_embedding_dim: int = 3072,
|
||||
num_attention_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
activation_fn: str = "relu",
|
||||
layer_norm_first: bool = False,
|
||||
has_relative_attention_bias: bool = False,
|
||||
num_buckets: int = 0,
|
||||
max_distance: int = 0,
|
||||
rescale_init: bool = False,
|
||||
gru_rel_pos: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Initialize parameters
|
||||
self.embedding_dim = embedding_dim
|
||||
self.dropout = dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
|
||||
# Initialize blocks
|
||||
self.activation_name = activation_fn
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.self_attn = MultiheadAttention(
|
||||
self.embedding_dim,
|
||||
num_attention_heads,
|
||||
dropout=attention_dropout,
|
||||
self_attention=True,
|
||||
has_relative_attention_bias=has_relative_attention_bias,
|
||||
num_buckets=num_buckets,
|
||||
max_distance=max_distance,
|
||||
rescale_init=rescale_init,
|
||||
gru_rel_pos=gru_rel_pos,
|
||||
)
|
||||
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.layer_norm_first = layer_norm_first
|
||||
|
||||
# layer norm associated with the self attention layer
|
||||
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||
|
||||
if self.activation_name == "glu":
|
||||
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
||||
else:
|
||||
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||
|
||||
# layer norm associated with the position wise feed-forward NN
|
||||
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
self_attn_mask: torch.Tensor = None,
|
||||
self_attn_padding_mask: torch.Tensor = None,
|
||||
need_weights: bool = False,
|
||||
pos_bias=None,
|
||||
):
|
||||
"""
|
||||
LayerNorm is applied either before or after the self-attention/ffn
|
||||
modules, similar to the original Transformer implementation.
|
||||
"""
|
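# Descriptive note (added; not in the original source): layer_norm_first selects
# the pre-LN variant (normalize -> attend/FFN -> residual); the else-branch below
# is the post-LN variant that normalizes after each residual connection.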
||||
residual = x
|
||||
|
||||
if self.layer_norm_first:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn, pos_bias = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=False,
|
||||
attn_mask=self_attn_mask,
|
||||
position_bias=pos_bias,
|
||||
)
|
||||
x = self.dropout1(x)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
if self.activation_name == "glu":
|
||||
x = self.fc1(x)
|
||||
else:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout3(x)
|
||||
x = residual + x
|
||||
else:
|
||||
x, attn, pos_bias = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=self_attn_mask,
|
||||
position_bias=pos_bias,
|
||||
)
|
||||
|
||||
x = self.dropout1(x)
|
||||
x = residual + x
|
||||
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.activation_name == "glu":
|
||||
x = self.fc1(x)
|
||||
else:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout3(x)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
return x, attn, pos_bias
|
||||
@@ -0,0 +1,39 @@
|
||||
# Mozilla TTS Vocoders (Experimental)
|
||||
|
||||
This folder contains vocoder model implementations that can be combined with the other TTS models.
|
||||
|
||||
Currently, the following models are implemented:
|
||||
|
||||
- Melgan
|
||||
- MultiBand-Melgan
|
||||
- ParallelWaveGAN
|
||||
- GAN-TTS (Discriminator Only)
|
||||
|
||||
It is also very easy to adapt different vocoder models as we provide a flexible and modular (but not too modular) framework.
|
||||
|
||||
## Training a model
|
||||
|
||||
An example [Colab Notebook]() (coming soon) shows how to train MelGAN on the LJSpeech dataset.
|
||||
|
||||
In order to train a new model, you need to gather all wav files into a folder and set this folder as `data_path` in `config.json`.
|
||||
|
||||
You need to define the other relevant parameters in your `config.json` and then start training with the following command.
|
||||
|
||||
`CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --config_path path/to/config.json`
|
||||
|
||||
Example config files can be found under the `tts/vocoder/configs/` folder.
|
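If you prefer building a config programmatically, the dataclass configs shipped with this package can be instantiated and serialized. A minimal sketch, assuming the Coqpit-based `save_json()` method and a `data_path` field on the vocoder configs:

```python
from TTS.vocoder.configs import MelganConfig

config = MelganConfig()              # start from the MelGAN defaults
config.data_path = "/path/to/wavs"   # folder that holds your training wav files
config.save_json("config.json")      # write the config consumed by train_vocoder.py
```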
||||
|
||||
You can continue a previous training run with the following command.
|
||||
|
||||
`CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --continue_path path/to/your/model/folder`
|
||||
|
||||
You can fine-tune a pre-trained model with the following command.
|
||||
|
||||
`CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --restore_path path/to/your/model.pth`
|
||||
|
||||
Restoring a model starts a new training run in a different folder; it only restores model weights from the given checkpoint file. Continuing a training run, in contrast, resumes from the same directory where the previous run left off.
|
||||
|
||||
You can also follow your training runs on TensorBoard as you do with our TTS models.
|
||||
|
||||
## Acknowledgement
|
||||
Thanks to @kan-bayashi for his [repository](https://github.com/kan-bayashi/ParallelWaveGAN) being the starting point of our work.
|
||||
Binary file not shown.
@@ -0,0 +1,17 @@
|
||||
import importlib
|
||||
import os
|
||||
from inspect import isclass
|
||||
|
||||
# import all files under configs/
|
||||
configs_dir = os.path.dirname(__file__)
|
||||
for file in os.listdir(configs_dir):
|
||||
path = os.path.join(configs_dir, file)
|
||||
if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
||||
config_name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
module = importlib.import_module("TTS.vocoder.configs." + config_name)
|
||||
for attribute_name in dir(module):
|
||||
attribute = getattr(module, attribute_name)
|
||||
|
||||
if isclass(attribute):
|
||||
# Add the class to this package's variables
|
||||
globals()[attribute_name] = attribute
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,106 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullbandMelganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for FullBand MelGAN vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import FullbandMelganConfig
|
||||
>>> config = FullbandMelganConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `fullband_melgan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
`melgan_multiscale_discriminator`.
|
||||
discriminator_model_params (dict): The discriminator model parameters. Defaults to
|
||||
'{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}`
|
||||
generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `melgan_generator`.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 2000.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
use_stft_loss (bool):
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft_loss (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to False.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squared Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
|
||||
stft_loss_params (dict): STFT loss parameters. Default to
|
||||
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. Defaults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
"""
|
||||
|
||||
model: str = "fullband_melgan"
|
||||
|
||||
# Model specific params
|
||||
discriminator_model: str = "melgan_multiscale_discriminator"
|
||||
discriminator_model_params: dict = field(
|
||||
default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]}
|
||||
)
|
||||
generator_model: str = "melgan_generator"
|
||||
generator_model_params: dict = field(
|
||||
default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 4}
|
||||
)
|
||||
|
||||
# Training - overrides
|
||||
batch_size: int = 16
|
||||
seq_len: int = 8192
|
||||
pad_short: int = 2000
|
||||
use_noise_augment: bool = True
|
||||
use_cache: bool = True
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = False
|
||||
|
||||
stft_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240],
|
||||
}
|
||||
)
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0.5
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 2.5
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 108
|
||||
l1_spec_loss_weight: float = 0.0
|
||||
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class HifiganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for FullBand MelGAN vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import HifiganConfig
|
||||
>>> config = HifiganConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `hifigan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
`hifigan_discriminator`.
|
||||
generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `hifigan_generator`.
|
||||
generator_model_params (dict): Parameters of the generator model. Defaults to
|
||||
`
|
||||
{
|
||||
"upsample_factors": [8, 8, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_type": "1",
|
||||
}
|
||||
`
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
use_stft_loss (bool):
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to False.
|
||||
use_subband_stft_loss (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to False.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squared Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to True.
|
||||
stft_loss_params (dict):
|
||||
STFT loss parameters. Default to
|
||||
`{
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240]
|
||||
}`
|
||||
l1_spec_loss_params (dict):
|
||||
L1 spectrogram loss parameters. Default to
|
||||
`{
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 1.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. Defaults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 45.
|
||||
"""
|
||||
|
||||
model: str = "hifigan"
|
||||
# model specific params
|
||||
discriminator_model: str = "hifigan_discriminator"
|
||||
generator_model: str = "hifigan_generator"
|
||||
generator_model_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"upsample_factors": [8, 8, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_type": "1",
|
||||
}
|
||||
)
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = False
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = True
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 1
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 108
|
||||
l1_spec_loss_weight: float = 45
|
||||
l1_spec_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"use_mel": True,
|
||||
"sample_rate": 22050,
|
||||
"n_fft": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_mels": 80,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": None,
|
||||
}
|
||||
)
|
||||
|
||||
# optimizer parameters
|
||||
lr: float = 1e-4
|
||||
wd: float = 1e-6
|
||||
@@ -0,0 +1,106 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MelganConfig(BaseGANVocoderConfig):
|
||||
"""Defines parameters for MelGAN vocoder.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vocoder.configs import MelganConfig
|
||||
>>> config = MelganConfig()
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `melgan`.
|
||||
discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
|
||||
`melgan_multiscale_discriminator`.
|
||||
discriminator_model_params (dict): The discriminator model parameters. Defaults to
|
||||
'{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}`
|
||||
generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
|
||||
considered as a generator too. Defaults to `melgan_generator`.
|
||||
batch_size (int):
|
||||
Batch size used at training. Larger values use more memory. Defaults to 16.
|
||||
seq_len (int):
|
||||
Audio segment length used at training. Larger values use more memory. Defaults to 8192.
|
||||
pad_short (int):
|
||||
Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 2000.
|
||||
use_noise_augment (bool):
|
||||
enable / disable random noise added to the input waveform. The noise is added after computing the
|
||||
features. Defaults to True.
|
||||
use_cache (bool):
|
||||
enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
|
||||
not large enough. Defaults to True.
|
||||
use_stft_loss (bool):
|
||||
enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
|
||||
use_subband_stft_loss (bool):
|
||||
enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to False.
|
||||
use_mse_gan_loss (bool):
|
||||
enable / disable using Mean Squared Error GAN loss. Defaults to True.
|
||||
use_hinge_gan_loss (bool):
|
||||
enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
|
||||
Defaults to False.
|
||||
use_feat_match_loss (bool):
|
||||
enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
|
||||
use_l1_spec_loss (bool):
|
||||
enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
|
||||
stft_loss_params (dict): STFT loss parameters. Default to
|
||||
`{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
|
||||
stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
|
||||
model loss. Defaults to 0.5.
|
||||
subband_stft_loss_weight (float):
|
||||
Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
mse_G_loss_weight (float):
|
||||
MSE generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 2.5.
|
||||
hinge_G_loss_weight (float):
|
||||
Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
feat_match_loss_weight (float):
|
||||
Feature matching loss weight that multiplies the computed loss before summing up the total loss. Defaults to 108.
|
||||
l1_spec_loss_weight (float):
|
||||
L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
|
||||
"""
|
||||
|
||||
model: str = "melgan"
|
||||
|
||||
# Model specific params
|
||||
discriminator_model: str = "melgan_multiscale_discriminator"
|
||||
discriminator_model_params: dict = field(
|
||||
default_factory=lambda: {"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}
|
||||
)
|
||||
generator_model: str = "melgan_generator"
|
||||
generator_model_params: dict = field(
|
||||
default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 3}
|
||||
)
|
||||
|
||||
# Training - overrides
|
||||
batch_size: int = 16
|
||||
seq_len: int = 8192
|
||||
pad_short: int = 2000
|
||||
use_noise_augment: bool = True
|
||||
use_cache: bool = True
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
use_subband_stft_loss: bool = False
|
||||
use_mse_gan_loss: bool = True
|
||||
use_hinge_gan_loss: bool = False
|
||||
use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN)
|
||||
use_l1_spec_loss: bool = False
|
||||
|
||||
stft_loss_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240],
|
||||
}
|
||||
)
|
||||
|
||||
# loss weights - overrides
|
||||
stft_loss_weight: float = 0.5
|
||||
subband_stft_loss_weight: float = 0
|
||||
mse_G_loss_weight: float = 2.5
|
||||
hinge_G_loss_weight: float = 0
|
||||
feat_match_loss_weight: float = 108
|
||||
l1_spec_loss_weight: float = 0
|
||||
@@ -0,0 +1,144 @@
from dataclasses import dataclass, field

from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig


@dataclass
class MultibandMelganConfig(BaseGANVocoderConfig):
    """Defines parameters for MultiBandMelGAN vocoder.

    Example:

        >>> from TTS.vocoder.configs import MultibandMelganConfig
        >>> config = MultibandMelganConfig()

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `multiband_melgan`.
        discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `melgan_multiscale_discriminator`.
        discriminator_model_params (dict): The discriminator model parameters. Defaults to
            `{
                "base_channels": 16,
                "max_channels": 512,
                "downsample_factors": [4, 4, 4]
            }`
        generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
            considered as a generator too. Defaults to `multiband_melgan_generator`.
        generator_model_params (dict):
            The generator model parameters. Defaults to `{"upsample_factors": [8, 4, 2], "num_res_blocks": 4}`.
        use_pqmf (bool):
            enable / disable PQMF modulation for multi-band training. Defaults to True.
        lr_gen (float):
            Initial learning rate for the generator model. Defaults to 0.0001.
        lr_disc (float):
            Initial learning rate for the discriminator model. Defaults to 0.0001.
        optimizer (torch.optim.Optimizer):
            Optimizer used for the training. Defaults to `AdamW`.
        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
        lr_scheduler_gen (torch.optim.Scheduler):
            Learning rate scheduler for the generator. Defaults to `MultiStepLR`.
        lr_scheduler_gen_params (dict):
            Parameters for the generator learning rate scheduler. Defaults to
            `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`.
        lr_scheduler_disc (torch.optim.Scheduler):
            Learning rate scheduler for the discriminator. Defaults to `MultiStepLR`.
        lr_scheduler_disc_params (dict):
            Parameters for the discriminator learning rate scheduler. Defaults to
            `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`.
        batch_size (int):
            Batch size used at training. Larger values use more memory. Defaults to 64.
        seq_len (int):
            Audio segment length used at training. Larger values use more memory. Defaults to 16384.
        pad_short (int):
            Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 2000.
        use_noise_augment (bool):
            enable / disable random noise added to the input waveform. The noise is added after computing the
            features. Defaults to False.
        use_cache (bool):
            enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
            not large enough. Defaults to True.
        steps_to_start_discriminator (int):
            Number of steps required to start training the discriminator. Defaults to 200000.
        use_stft_loss (bool):
            enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
        use_subband_stft_loss (bool):
            enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
        use_mse_gan_loss (bool):
            enable / disable using Mean Squared Error GAN loss. Defaults to True.
        use_hinge_gan_loss (bool):
            enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
            Defaults to False.
        use_feat_match_loss (bool):
            enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to False.
        use_l1_spec_loss (bool):
            enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
        stft_loss_params (dict): STFT loss parameters. Defaults to
            `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
        stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
            model loss. Defaults to 0.5.
        subband_stft_loss_weight (float):
            Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        mse_G_loss_weight (float):
            MSE generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 2.5.
        hinge_G_loss_weight (float):
            Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        feat_match_loss_weight (float):
            Feature matching loss weight that multiplies the computed loss before summing up the total loss. Defaults to 108.
        l1_spec_loss_weight (float):
            L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
    """

    model: str = "multiband_melgan"

    # Model specific params
    discriminator_model: str = "melgan_multiscale_discriminator"
    discriminator_model_params: dict = field(
        default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]}
    )
    generator_model: str = "multiband_melgan_generator"
    generator_model_params: dict = field(default_factory=lambda: {"upsample_factors": [8, 4, 2], "num_res_blocks": 4})
    use_pqmf: bool = True

    # optimizer - overrides
    lr_gen: float = 0.0001  # Initial learning rate.
    lr_disc: float = 0.0001  # Initial learning rate.
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
    lr_scheduler_gen: str = "MultiStepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_gen_params: dict = field(
        default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
    )
    lr_scheduler_disc: str = "MultiStepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_disc_params: dict = field(
        default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
    )

    # Training - overrides
    batch_size: int = 64
    seq_len: int = 16384
    pad_short: int = 2000
    use_noise_augment: bool = False
    use_cache: bool = True
    steps_to_start_discriminator: int = 200000

    # LOSS PARAMETERS - overrides
    use_stft_loss: bool = True
    use_subband_stft_loss: bool = True
    use_mse_gan_loss: bool = True
    use_hinge_gan_loss: bool = False
    use_feat_match_loss: bool = False  # requires MelGAN Discriminators (MelGAN and HifiGAN)
    use_l1_spec_loss: bool = False

    subband_stft_loss_params: dict = field(
        default_factory=lambda: {"n_ffts": [384, 683, 171], "hop_lengths": [30, 60, 10], "win_lengths": [150, 300, 60]}
    )

    # loss weights - overrides
    stft_loss_weight: float = 0.5
    subband_stft_loss_weight: float = 0
    mse_G_loss_weight: float = 2.5
    hinge_G_loss_weight: float = 0
    feat_match_loss_weight: float = 108
    l1_spec_loss_weight: float = 0
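As a quick illustration of how these Coqpit dataclasses are typically consumed (a sketch assuming the standard Coqpit to_dict/from_dict round trip), a config survives serialization without losing its overridden defaults:

# Sketch: round-trip a config through a plain dict, assuming Coqpit's
# to_dict()/from_dict() API.
from TTS.vocoder.configs import MultibandMelganConfig

config = MultibandMelganConfig()
restored = MultibandMelganConfig()
restored.from_dict(config.to_dict())
assert restored.steps_to_start_discriminator == 200000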
@@ -0,0 +1,134 @@
from dataclasses import dataclass, field

from .shared_configs import BaseGANVocoderConfig


@dataclass
class ParallelWaveganConfig(BaseGANVocoderConfig):
    """Defines parameters for ParallelWavegan vocoder.

    Args:
        model (str):
            Model name used for selecting the right configuration at initialization. Defaults to `parallel_wavegan`.
        discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `parallel_wavegan_discriminator`.
        discriminator_model_params (dict): The discriminator model kwargs. Defaults to
            `{"num_layers": 10}`
        generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
            considered as a generator too. Defaults to `parallel_wavegan_generator`.
        generator_model_params (dict):
            The generator model kwargs. Defaults to `{"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}`.
        batch_size (int):
            Batch size used at training. Larger values use more memory. Defaults to 6.
        seq_len (int):
            Audio segment length used at training. Larger values use more memory. Defaults to 25600.
        pad_short (int):
            Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 2000.
        use_noise_augment (bool):
            enable / disable random noise added to the input waveform. The noise is added after computing the
            features. Defaults to False.
        use_cache (bool):
            enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
            not large enough. Defaults to True.
        steps_to_start_discriminator (int):
            Number of steps required to start training the discriminator. Defaults to 200000.
        use_stft_loss (bool):
            enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
        use_subband_stft_loss (bool):
            enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to False.
        use_mse_gan_loss (bool):
            enable / disable using Mean Squared Error GAN loss. Defaults to True.
        use_hinge_gan_loss (bool):
            enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
            Defaults to False.
        use_feat_match_loss (bool):
            enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to False.
        use_l1_spec_loss (bool):
            enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
        stft_loss_params (dict): STFT loss parameters. Defaults to
            `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`
        stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
            model loss. Defaults to 0.5.
        subband_stft_loss_weight (float):
            Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        mse_G_loss_weight (float):
            MSE generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 2.5.
        hinge_G_loss_weight (float):
            Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        feat_match_loss_weight (float):
            Feature matching loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        l1_spec_loss_weight (float):
            L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        lr_gen (float):
            Generator model initial learning rate. Defaults to 0.0002.
        lr_disc (float):
            Discriminator model initial learning rate. Defaults to 0.0002.
        optimizer (torch.optim.Optimizer):
            Optimizer used for the training. Defaults to `AdamW`.
        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
        lr_scheduler_gen (torch.optim.Scheduler):
            Learning rate scheduler for the generator. Defaults to `StepLR`.
        lr_scheduler_gen_params (dict):
            Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`.
        lr_scheduler_disc (torch.optim.Scheduler):
            Learning rate scheduler for the discriminator. Defaults to `StepLR`.
        lr_scheduler_disc_params (dict):
            Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`.
    """

    model: str = "parallel_wavegan"

    # Model specific params
    discriminator_model: str = "parallel_wavegan_discriminator"
    discriminator_model_params: dict = field(default_factory=lambda: {"num_layers": 10})
    generator_model: str = "parallel_wavegan_generator"
    generator_model_params: dict = field(
        default_factory=lambda: {"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}
    )

    # Training - overrides
    batch_size: int = 6
    seq_len: int = 25600
    pad_short: int = 2000
    use_noise_augment: bool = False
    use_cache: bool = True
    steps_to_start_discriminator: int = 200000
    target_loss: str = "loss_1"

    # LOSS PARAMETERS - overrides
    use_stft_loss: bool = True
    use_subband_stft_loss: bool = False
    use_mse_gan_loss: bool = True
    use_hinge_gan_loss: bool = False
    use_feat_match_loss: bool = False  # requires MelGAN Discriminators (MelGAN and HifiGAN)
    use_l1_spec_loss: bool = False

    stft_loss_params: dict = field(
        default_factory=lambda: {
            "n_ffts": [1024, 2048, 512],
            "hop_lengths": [120, 240, 50],
            "win_lengths": [600, 1200, 240],
        }
    )

    # loss weights - overrides
    stft_loss_weight: float = 0.5
    subband_stft_loss_weight: float = 0
    mse_G_loss_weight: float = 2.5
    hinge_G_loss_weight: float = 0
    feat_match_loss_weight: float = 0
    l1_spec_loss_weight: float = 0

    # optimizer overrides
    lr_gen: float = 0.0002  # Initial learning rate.
    lr_disc: float = 0.0002  # Initial learning rate.
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
    lr_scheduler_gen: str = "StepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
    lr_scheduler_disc: str = "StepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_disc_params: dict = field(
        default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}
    )
    scheduler_after_epoch: bool = False
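With `scheduler_after_epoch=False` the StepLR schedulers are stepped per iteration, so `step_size=200000` counts optimizer steps rather than epochs. A minimal standalone sketch of that behavior (the linear layer is just a stand-in for the generator):

# Hypothetical training-loop fragment illustrating per-step LR scheduling.
import torch

model = torch.nn.Linear(10, 1)  # stand-in for the generator
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, betas=(0.8, 0.99), weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200000, gamma=0.5)
for step in range(3):
    optimizer.step()
    scheduler.step()  # stepped per iteration, not per epoch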
@@ -0,0 +1,182 @@
from dataclasses import dataclass, field

from TTS.config import BaseAudioConfig, BaseTrainingConfig


@dataclass
class BaseVocoderConfig(BaseTrainingConfig):
    """Shared parameters among all the vocoder models.
    Args:
        audio (BaseAudioConfig):
            Audio processor config instance. Defaults to `BaseAudioConfig()`.
        use_noise_augment (bool):
            Augment the input audio with random noise. Defaults to False.
        eval_split_size (int):
            Number of instances used for evaluation. Defaults to 10.
        data_path (str):
            Root path of the training data. All the audio files found recursively from this root path are used for
            training. Defaults to `""`.
        feature_path (str):
            Root path to the precomputed feature files. Defaults to None.
        seq_len (int):
            Length of the waveform segments used for training. Defaults to 1000.
        pad_short (int):
            Extra padding for the waveforms shorter than `seq_len`. Defaults to 0.
        conv_pad (int):
            Extra padding for the feature frames against convolution of the edge frames. Defaults to 0.
        use_cache (bool):
            enable / disable in memory caching of the computed features. If the RAM is not enough, it may cause OOM.
            Defaults to False.
        epochs (int):
            Number of training epochs. Defaults to 10000.
        wd (float):
            Weight decay.
        optimizer (torch.optim.Optimizer):
            Optimizer used for the training. Defaults to `AdamW`.
        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
    """

    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    # dataloading
    use_noise_augment: bool = False  # enable/disable random noise augmentation in spectrograms.
    eval_split_size: int = 10  # number of samples used for evaluation.
    # dataset
    data_path: str = ""  # root data path. It finds all wav files recursively from there.
    feature_path: str = None  # if you use precomputed features
    seq_len: int = 1000  # signal length used in training.
    pad_short: int = 0  # additional padding for short wavs
    conv_pad: int = 0  # additional padding against convolutions applied to spectrograms
    use_cache: bool = False  # use in memory cache to keep the computed features. This might cause OOM.
    # OPTIMIZER
    epochs: int = 10000  # total number of epochs to train.
    wd: float = 0.0  # Weight decay weight.
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})


@dataclass
class BaseGANVocoderConfig(BaseVocoderConfig):
    """Base config class used among all the GAN based vocoders.
    Args:
        use_stft_loss (bool):
            enable / disable the use of STFT loss. Defaults to True.
        use_subband_stft_loss (bool):
            enable / disable the use of Subband STFT loss. Defaults to True.
        use_mse_gan_loss (bool):
            enable / disable the use of Mean Squared Error based GAN loss. Defaults to True.
        use_hinge_gan_loss (bool):
            enable / disable the use of Hinge GAN loss. Defaults to True.
        use_feat_match_loss (bool):
            enable / disable feature matching loss. Defaults to True.
        use_l1_spec_loss (bool):
            enable / disable L1 spectrogram loss. Defaults to True.
        stft_loss_weight (float):
            Loss weight that multiplies the computed loss value. Defaults to 0.
        subband_stft_loss_weight (float):
            Loss weight that multiplies the computed loss value. Defaults to 0.
        mse_G_loss_weight (float):
            Loss weight that multiplies the computed loss value. Defaults to 1.
        hinge_G_loss_weight (float):
            Loss weight that multiplies the computed loss value. Defaults to 0.
        feat_match_loss_weight (float):
            Loss weight that multiplies the computed loss value. Defaults to 100.
        l1_spec_loss_weight (float):
            Loss weight that multiplies the computed loss value. Defaults to 45.
        stft_loss_params (dict):
            Parameters for the STFT loss. Defaults to `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`.
        l1_spec_loss_params (dict):
            Parameters for the L1 spectrogram loss. Defaults to
            `{
                "use_mel": True,
                "sample_rate": 22050,
                "n_fft": 1024,
                "hop_length": 256,
                "win_length": 1024,
                "n_mels": 80,
                "mel_fmin": 0.0,
                "mel_fmax": None,
            }`
        target_loss (str):
            Target loss name that defines the quality of the model. Defaults to `loss_0`.
        grad_clip (list):
            A list of gradient clipping thresholds for each optimizer. Any value less than 0 disables clipping.
            Defaults to [5, 5].
        lr_gen (float):
            Generator model initial learning rate. Defaults to 0.0002.
        lr_disc (float):
            Discriminator model initial learning rate. Defaults to 0.0002.
        lr_scheduler_gen (torch.optim.Scheduler):
            Learning rate scheduler for the generator. Defaults to `ExponentialLR`.
        lr_scheduler_gen_params (dict):
            Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
        lr_scheduler_disc (torch.optim.Scheduler):
            Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`.
        lr_scheduler_disc_params (dict):
            Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`.
        scheduler_after_epoch (bool):
            Whether to update the learning rate schedulers after each epoch. Defaults to True.
        use_pqmf (bool):
            enable / disable PQMF for subband approximation at training. Defaults to False.
        steps_to_start_discriminator (int):
            Number of steps required to start training the discriminator. Defaults to 0.
        diff_samples_for_G_and_D (bool):
            enable / disable use of different training samples for the generator and the discriminator iterations.
            Enabling it results in slower iterations but faster convergence in some cases. Defaults to False.
    """

    model: str = "gan"

    # LOSS PARAMETERS
    use_stft_loss: bool = True
    use_subband_stft_loss: bool = True
    use_mse_gan_loss: bool = True
    use_hinge_gan_loss: bool = True
    use_feat_match_loss: bool = True  # requires MelGAN Discriminators (MelGAN and HifiGAN)
    use_l1_spec_loss: bool = True

    # loss weights
    stft_loss_weight: float = 0
    subband_stft_loss_weight: float = 0
    mse_G_loss_weight: float = 1
    hinge_G_loss_weight: float = 0
    feat_match_loss_weight: float = 100
    l1_spec_loss_weight: float = 45

    stft_loss_params: dict = field(
        default_factory=lambda: {
            "n_ffts": [1024, 2048, 512],
            "hop_lengths": [120, 240, 50],
            "win_lengths": [600, 1200, 240],
        }
    )

    l1_spec_loss_params: dict = field(
        default_factory=lambda: {
            "use_mel": True,
            "sample_rate": 22050,
            "n_fft": 1024,
            "hop_length": 256,
            "win_length": 1024,
            "n_mels": 80,
            "mel_fmin": 0.0,
            "mel_fmax": None,
        }
    )

    target_loss: str = "loss_0"  # loss value to pick the best model to save after each epoch

    # optimizer
    grad_clip: list = field(default_factory=lambda: [5, 5])
    lr_gen: float = 0.0002  # Initial learning rate.
    lr_disc: float = 0.0002  # Initial learning rate.
    lr_scheduler_gen: str = "ExponentialLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    scheduler_after_epoch: bool = True

    use_pqmf: bool = False  # enable/disable using pqmf for multi-band training. (Multi-band MelGAN)
    steps_to_start_discriminator: int = 0  # start training the discriminator after this number of steps.
    diff_samples_for_G_and_D: bool = False  # use different samples for G and D training steps.
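The `field(default_factory=...)` pattern used throughout these configs is required because dataclasses reject shared mutable defaults. A small standalone sketch of why:

# Standalone sketch: default_factory gives each instance its own list.
from dataclasses import dataclass, field

@dataclass
class LossParams:
    n_ffts: list = field(default_factory=lambda: [1024, 2048, 512])

a, b = LossParams(), LossParams()
a.n_ffts.append(256)
assert b.n_ffts == [1024, 2048, 512]  # instances do not share the list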
@@ -0,0 +1,161 @@
from dataclasses import dataclass, field
from typing import Dict

from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig


@dataclass
class UnivnetConfig(BaseGANVocoderConfig):
    """Defines parameters for UnivNet vocoder.

    Example:

        >>> from TTS.vocoder.configs import UnivnetConfig
        >>> config = UnivnetConfig()

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `univnet`.
        discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `univnet_discriminator`.
        generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
            considered as a generator too. Defaults to `univnet_generator`.
        generator_model_params (dict): Parameters of the generator model. Defaults to
            `{
                "in_channels": 64,
                "out_channels": 1,
                "hidden_channels": 32,
                "cond_channels": 80,
                "upsample_factors": [8, 8, 4],
                "lvc_layers_each_block": 4,
                "lvc_kernel_size": 3,
                "kpnet_hidden_channels": 64,
                "kpnet_conv_size": 3,
                "dropout": 0.0,
            }`
        batch_size (int):
            Batch size used at training. Larger values use more memory. Defaults to 32.
        seq_len (int):
            Audio segment length used at training. Larger values use more memory. Defaults to 8192.
        pad_short (int):
            Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
        use_noise_augment (bool):
            enable / disable random noise added to the input waveform. The noise is added after computing the
            features. Defaults to True.
        use_cache (bool):
            enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
            not large enough. Defaults to True.
        use_stft_loss (bool):
            enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
        use_subband_stft_loss (bool):
            enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to False.
        use_mse_gan_loss (bool):
            enable / disable using Mean Squared Error GAN loss. Defaults to True.
        use_hinge_gan_loss (bool):
            enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
            Defaults to False.
        use_feat_match_loss (bool):
            enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to False.
        use_l1_spec_loss (bool):
            enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False.
        stft_loss_params (dict):
            STFT loss parameters. Defaults to
            `{
                "n_ffts": [1024, 2048, 512],
                "hop_lengths": [120, 240, 50],
                "win_lengths": [600, 1200, 240]
            }`
        l1_spec_loss_params (dict):
            L1 spectrogram loss parameters. Defaults to
            `{
                "use_mel": True,
                "sample_rate": 22050,
                "n_fft": 1024,
                "hop_length": 256,
                "win_length": 1024,
                "n_mels": 80,
                "mel_fmin": 0.0,
                "mel_fmax": None,
            }`
        stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
            model loss. Defaults to 2.5.
        subband_stft_loss_weight (float):
            Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        mse_G_loss_weight (float):
            MSE generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 1.
        hinge_G_loss_weight (float):
            Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        feat_match_loss_weight (float):
            Feature matching loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        l1_spec_loss_weight (float):
            L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
    """

    model: str = "univnet"
    batch_size: int = 32
    # model specific params
    discriminator_model: str = "univnet_discriminator"
    generator_model: str = "univnet_generator"
    generator_model_params: Dict = field(
        default_factory=lambda: {
            "in_channels": 64,
            "out_channels": 1,
            "hidden_channels": 32,
            "cond_channels": 80,
            "upsample_factors": [8, 8, 4],
            "lvc_layers_each_block": 4,
            "lvc_kernel_size": 3,
            "kpnet_hidden_channels": 64,
            "kpnet_conv_size": 3,
            "dropout": 0.0,
        }
    )

    # LOSS PARAMETERS - overrides
    use_stft_loss: bool = True
    use_subband_stft_loss: bool = False
    use_mse_gan_loss: bool = True
    use_hinge_gan_loss: bool = False
    use_feat_match_loss: bool = False  # requires MelGAN Discriminators (MelGAN and HifiGAN)
    use_l1_spec_loss: bool = False

    # loss weights - overrides
    stft_loss_weight: float = 2.5
    stft_loss_params: Dict = field(
        default_factory=lambda: {
            "n_ffts": [1024, 2048, 512],
            "hop_lengths": [120, 240, 50],
            "win_lengths": [600, 1200, 240],
        }
    )
    subband_stft_loss_weight: float = 0
    mse_G_loss_weight: float = 1
    hinge_G_loss_weight: float = 0
    feat_match_loss_weight: float = 0
    l1_spec_loss_weight: float = 0
    l1_spec_loss_params: Dict = field(
        default_factory=lambda: {
            "use_mel": True,
            "sample_rate": 22050,
            "n_fft": 1024,
            "hop_length": 256,
            "win_length": 1024,
            "n_mels": 80,
            "mel_fmin": 0.0,
            "mel_fmax": None,
        }
    )

    # optimizer parameters
    lr_gen: float = 1e-4  # Initial learning rate.
    lr_disc: float = 1e-4  # Initial learning rate.
    lr_scheduler_gen: str = None  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    lr_scheduler_disc: str = None  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
    steps_to_start_discriminator: int = 200000

    def __post_init__(self):
        super().__post_init__()
        self.generator_model_params["cond_channels"] = self.audio.num_mels
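`__post_init__` keeps the generator's conditioning width in lockstep with the audio config. A quick sketch of the effect (assuming `num_mels` is exposed on the base audio config, as the code above implies):

# Hypothetical check: cond_channels always follows audio.num_mels after init.
from TTS.vocoder.configs import UnivnetConfig

config = UnivnetConfig()
assert config.generator_model_params["cond_channels"] == config.audio.num_mels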
@@ -0,0 +1,90 @@
from dataclasses import dataclass, field

from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
from TTS.vocoder.models.wavegrad import WavegradArgs


@dataclass
class WavegradConfig(BaseVocoderConfig):
    """Defines parameters for WaveGrad vocoder.
    Example:

        >>> from TTS.vocoder.configs import WavegradConfig
        >>> config = WavegradConfig()

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `wavegrad`.
        generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
            considered as a generator too. Defaults to `wavegrad`.
        model_params (WavegradArgs): Model parameters. Check `WavegradArgs` for default values.
        target_loss (str):
            Target loss name that defines the quality of the model. Defaults to `loss`.
        epochs (int):
            Number of epochs to train the model. Defaults to 10000.
        batch_size (int):
            Batch size used at training. Larger values use more memory. Defaults to 96.
        seq_len (int):
            Audio segment length used at training. Larger values use more memory. Defaults to 6144.
        use_cache (bool):
            enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
            not large enough. Defaults to True.
        mixed_precision (bool):
            enable / disable mixed precision training. Default is True.
        eval_split_size (int):
            Number of samples used for evaluation. Defaults to 50.
        train_noise_schedule (dict):
            Training noise schedule. Defaults to
            `{"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}`
        test_noise_schedule (dict):
            Inference noise schedule. For a better performance, you may need to use `bin/tune_wavegrad.py` to find a
            better schedule. Defaults to
            `{
                "min_val": 1e-6,
                "max_val": 1e-2,
                "num_steps": 50,
            }`
        grad_clip (float):
            Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 1.0.
        lr (float):
            Initial learning rate. Defaults to 1e-4.
        lr_scheduler (str):
            One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`.
        lr_scheduler_params (dict):
            kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`
    """

    model: str = "wavegrad"
    # Model specific params
    generator_model: str = "wavegrad"
    model_params: WavegradArgs = field(default_factory=WavegradArgs)
    target_loss: str = "loss"  # loss value to pick the best model to save after each epoch

    # Training - overrides
    epochs: int = 10000
    batch_size: int = 96
    seq_len: int = 6144
    use_cache: bool = True
    mixed_precision: bool = True
    eval_split_size: int = 50

    # NOISE SCHEDULE PARAMS
    train_noise_schedule: dict = field(default_factory=lambda: {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000})

    test_noise_schedule: dict = field(
        default_factory=lambda: {  # inference noise schedule. Try TTS/bin/tune_wavegrad.py to find the optimal values.
            "min_val": 1e-6,
            "max_val": 1e-2,
            "num_steps": 50,
        }
    )

    # optimizer overrides
    grad_clip: float = 1.0
    lr: float = 1e-4  # Initial learning rate.
    lr_scheduler: str = "MultiStepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_params: dict = field(
        default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}
    )
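The noise schedule dicts map naturally onto a linear beta grid. A minimal sketch of how such a `{"min_val", "max_val", "num_steps"}` schedule is typically expanded (an assumption about the sampling code, which is not part of this commit):

# Sketch: expand a noise schedule dict into per-step betas and signal levels.
import numpy as np

schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 50}
betas = np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])
alphas = np.cumprod(1.0 - betas)  # cumulative signal level per diffusion step
print(betas.shape, alphas[-1])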
@@ -0,0 +1,102 @@
from dataclasses import dataclass, field

from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
from TTS.vocoder.models.wavernn import WavernnArgs


@dataclass
class WavernnConfig(BaseVocoderConfig):
    """Defines parameters for Wavernn vocoder.
    Example:

        >>> from TTS.vocoder.configs import WavernnConfig
        >>> config = WavernnConfig()

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `wavernn`.
        mode (str):
            Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
            Gaussian Distribution and `bits` for quantized bits as the model's output.
        mulaw (bool):
            enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
            to `True`.
        generator_model (str):
            One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
            considered as a generator too. Defaults to `WaveRNN`.
        model_args (WavernnArgs):
            kwargs for the WaveRNN model. Defaults to
            `{
                "rnn_dims": 512,
                "fc_dims": 512,
                "compute_dims": 128,
                "res_out_dims": 128,
                "num_res_blocks": 10,
                "use_aux_net": True,
                "use_upsample_net": True,
                "upsample_factors": [4, 8, 8]
            }`
        batched (bool):
            enable / disable the batched inference. It speeds up the inference by splitting the input into segments and
            processing the segments in a batch. Then it merges the outputs with a certain overlap and smoothing. If
            you set it False, without CUDA, it is too slow to be practical. Defaults to True.
        target_samples (int):
            Size of the segments in batched mode. Defaults to 11000.
        overlap_samples (int):
            Size of the overlap between consecutive segments. Defaults to 550.
        batch_size (int):
            Batch size used at training. Larger values use more memory. Defaults to 256.
        seq_len (int):
            Audio segment length used at training. Larger values use more memory. Defaults to 1280.

        use_noise_augment (bool):
            enable / disable random noise added to the input waveform. The noise is added after computing the
            features. Defaults to False.
        use_cache (bool):
            enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
            not large enough. Defaults to True.
        mixed_precision (bool):
            enable / disable mixed precision training. Default is True.
        eval_split_size (int):
            Number of samples used for evaluation. Defaults to 50.
        num_epochs_before_test (int):
            Number of epochs waited to run the next evaluation. Since inference takes some time, it is better to
            wait some number of epochs not to waste training time. Defaults to 10.
        grad_clip (float):
            Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 4.0.
        lr (float):
            Initial learning rate. Defaults to 1e-4.
        lr_scheduler (str):
            One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`.
        lr_scheduler_params (dict):
            kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [200000, 400000, 600000]}`
    """

    model: str = "wavernn"

    # Model specific params
    model_args: WavernnArgs = field(default_factory=WavernnArgs)
    target_loss: str = "loss"

    # Inference
    batched: bool = True
    target_samples: int = 11000
    overlap_samples: int = 550

    # Training - overrides
    epochs: int = 10000
    batch_size: int = 256
    seq_len: int = 1280
    use_noise_augment: bool = False
    use_cache: bool = True
    mixed_precision: bool = True
    eval_split_size: int = 50
    num_epochs_before_test: int = (
        10  # number of epochs to wait until the next test run (synthesizing a full audio clip).
    )

    # optimizer overrides
    grad_clip: float = 4.0
    lr: float = 1e-4  # Initial learning rate.
    lr_scheduler: str = "MultiStepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.5, "milestones": [200000, 400000, 600000]})
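`target_samples` and `overlap_samples` control how a long clip is folded into a batch at inference. A sketch of the resulting segment count (illustrative arithmetic only; the actual folding and cross-fading live in the WaveRNN model code, not in this file):

# Illustrative arithmetic for batched-inference folding.
target_samples, overlap_samples = 11000, 550
clip_len = 120000  # samples in a hypothetical clip
stride = target_samples - overlap_samples
num_segments = -(-max(clip_len - overlap_samples, 1) // stride)  # ceil division
print(num_segments)  # segments processed as one batch, then merged with overlap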
@@ -0,0 +1,58 @@
from typing import List

from coqpit import Coqpit
from torch.utils.data import Dataset

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset


def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
    if config.model.lower() in "gan":
        dataset = GANDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
    elif config.model.lower() == "wavegrad":
        dataset = WaveGradDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            is_training=not is_eval,
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
    elif config.model.lower() == "wavernn":
        dataset = WaveRNNDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad=config.model_params.pad,
            mode=config.model_params.mode,
            mulaw=config.model_params.mulaw,
            is_training=not is_eval,
            verbose=verbose,
        )
    else:
        raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
    return dataset
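A usage sketch for `setup_dataset` (the path and the `AudioProcessor` construction are assumptions, not part of this commit):

# Hypothetical wiring: split the wav files, then build a WaveGrad training set.
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import WavegradConfig
from TTS.vocoder.datasets import load_wav_data, setup_dataset

config = WavegradConfig(data_path="/data/wavs")  # hypothetical path
ap = AudioProcessor(**config.audio)
eval_items, train_items = load_wav_data(config.data_path, config.eval_split_size)
train_set = setup_dataset(config, ap, is_eval=False, data_items=train_items, verbose=True)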
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,152 @@
import glob
import os
import random
from multiprocessing import Manager

import numpy as np
import torch
from torch.utils.data import Dataset


class GANDataset(Dataset):
    """
    GAN Dataset searches for all the wav files under root path
    and converts them to acoustic features on the fly and returns
    random segments of (audio, feature) couples.
    """

    def __init__(
        self,
        ap,
        items,
        seq_len,
        hop_len,
        pad_short,
        conv_pad=2,
        return_pairs=False,
        is_training=True,
        return_segments=True,
        use_noise_augment=False,
        use_cache=False,
        verbose=False,
    ):
        super().__init__()
        self.ap = ap
        self.item_list = items
        self.compute_feat = not isinstance(items[0], (tuple, list))
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.pad_short = pad_short
        self.conv_pad = conv_pad
        self.return_pairs = return_pairs
        self.is_training = is_training
        self.return_segments = return_segments
        self.use_cache = use_cache
        self.use_noise_augment = use_noise_augment
        self.verbose = verbose

        assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
        self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)

        # map G and D instances
        self.G_to_D_mappings = list(range(len(self.item_list)))
        self.shuffle_mapping()

        # cache acoustic features
        if use_cache:
            self.create_feature_cache()

    def create_feature_cache(self):
        self.manager = Manager()
        self.cache = self.manager.list()
        self.cache += [None for _ in range(len(self.item_list))]

    @staticmethod
    def find_wav_files(path):
        return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, idx):
        """Return different items for Generator and Discriminator and
        cache acoustic features"""

        # set the seed differently for each worker
        if torch.utils.data.get_worker_info():
            random.seed(torch.utils.data.get_worker_info().seed)

        if self.return_segments:
            item1 = self.load_item(idx)
            if self.return_pairs:
                idx2 = self.G_to_D_mappings[idx]
                item2 = self.load_item(idx2)
                return item1, item2
            return item1
        item1 = self.load_item(idx)
        return item1

    def _pad_short_samples(self, audio, mel=None):
        """Pad samples shorter than the output sequence length"""
        if len(audio) < self.seq_len:
            audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0)

        if mel is not None and mel.shape[1] < self.feat_frame_len:
            pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]
            mel = np.pad(
                mel,
                ([0, 0], [0, self.feat_frame_len - mel.shape[1]]),
                mode="constant",
                constant_values=pad_value.mean(),
            )
        return audio, mel

    def shuffle_mapping(self):
        random.shuffle(self.G_to_D_mappings)

    def load_item(self, idx):
        """load (audio, feat) couple"""
        if self.compute_feat:
            # compute features from wav
            wavpath = self.item_list[idx]

            if self.use_cache and self.cache[idx] is not None:
                audio, mel = self.cache[idx]
            else:
                audio = self.ap.load_wav(wavpath)
                mel = self.ap.melspectrogram(audio)
                audio, mel = self._pad_short_samples(audio, mel)
        else:
            # load precomputed features
            wavpath, feat_path = self.item_list[idx]

            if self.use_cache and self.cache[idx] is not None:
                audio, mel = self.cache[idx]
            else:
                audio = self.ap.load_wav(wavpath)
                mel = np.load(feat_path)
                audio, mel = self._pad_short_samples(audio, mel)

        # correct the audio length wrt padding applied in stft
        audio = np.pad(audio, (0, self.hop_len), mode="edge")
        audio = audio[: mel.shape[-1] * self.hop_len]
        assert (
            mel.shape[-1] * self.hop_len == audio.shape[-1]
        ), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"

        audio = torch.from_numpy(audio).float().unsqueeze(0)
        mel = torch.from_numpy(mel).float().squeeze(0)

        if self.return_segments:
            max_mel_start = mel.shape[1] - self.feat_frame_len
            mel_start = random.randint(0, max_mel_start)
            mel_end = mel_start + self.feat_frame_len
            mel = mel[:, mel_start:mel_end]

            audio_start = mel_start * self.hop_len
            audio = audio[:, audio_start : audio_start + self.seq_len]

        if self.use_noise_augment and self.is_training and self.return_segments:
            audio = audio + (1 / 32768) * torch.randn_like(audio)
        return (mel, audio)
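The segment bookkeeping above ties waveform and feature lengths together through `feat_frame_len`. A quick illustration of the invariant (plain arithmetic with assumed hop values, no TTS imports):

# Illustration of the GANDataset length invariant.
seq_len, hop_len, conv_pad = 8192, 256, 2
assert seq_len % hop_len == 0
feat_frame_len = seq_len // hop_len + 2 * conv_pad  # 36 feature frames per segment
print(feat_frame_len)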
@@ -0,0 +1,75 @@
import glob
import os
from pathlib import Path

import numpy as np
from coqpit import Coqpit
from tqdm import tqdm

from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
    """Process wav and compute mel and quantized wave signal.
    It is mainly used by WaveRNN dataloader.

    Args:
        out_path (str): Parent folder path to save the files.
        config (Coqpit): Model config.
        ap (AudioProcessor): Audio processor.
    """
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    wav_files = find_wav_files(config.data_path)
    for path in tqdm(wav_files):
        wav_name = Path(path).stem
        quant_path = os.path.join(out_path, "quant", wav_name + ".npy")
        mel_path = os.path.join(out_path, "mel", wav_name + ".npy")
        y = ap.load_wav(path)
        mel = ap.melspectrogram(y)
        np.save(mel_path, mel)
        if isinstance(config.mode, int):
            quant = (
                mulaw_encode(wav=y, mulaw_qc=config.mode)
                if config.model_args.mulaw
                else quantize(x=y, quantize_bits=config.mode)
            )
            np.save(quant_path, quant)


def find_wav_files(data_path, file_ext="wav"):
    wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True)
    return wav_paths


def find_feat_files(data_path):
    feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True)
    return feat_paths


def load_wav_data(data_path, eval_split_size, file_ext="wav"):
    wav_paths = find_wav_files(data_path, file_ext=file_ext)
    assert len(wav_paths) > 0, f" [!] {data_path} is empty."
    np.random.seed(0)
    np.random.shuffle(wav_paths)
    return wav_paths[:eval_split_size], wav_paths[eval_split_size:]


def load_wav_feat_data(data_path, feat_path, eval_split_size):
    wav_paths = find_wav_files(data_path)
    feat_paths = find_feat_files(feat_path)

    wav_paths.sort(key=lambda x: Path(x).stem)
    feat_paths.sort(key=lambda x: Path(x).stem)

    assert len(wav_paths) == len(feat_paths), f" [!] {len(wav_paths)} vs {feat_paths}"
    for wav, feat in zip(wav_paths, feat_paths):
        wav_name = Path(wav).stem
        feat_name = Path(feat).stem
        assert wav_name == feat_name

    items = list(zip(wav_paths, feat_paths))
    np.random.seed(0)
    np.random.shuffle(items)
    return items[:eval_split_size], items[eval_split_size:]
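A usage sketch for the split helper above (the data path is hypothetical):

# Hypothetical split: 10 eval clips, the rest for training.
from TTS.vocoder.datasets.preprocess import load_wav_data

eval_items, train_items = load_wav_data("/data/wavs", eval_split_size=10)
print(len(eval_items), len(train_items))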
@@ -0,0 +1,151 @@
import glob
import os
import random
from multiprocessing import Manager
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset


class WaveGradDataset(Dataset):
    """
    WaveGrad Dataset searches for all the wav files under root path
    and converts them to acoustic features on the fly and returns
    random segments of (audio, feature) couples.
    """

    def __init__(
        self,
        ap,
        items,
        seq_len,
        hop_len,
        pad_short,
        conv_pad=2,
        is_training=True,
        return_segments=True,
        use_noise_augment=False,
        use_cache=False,
        verbose=False,
    ):
        super().__init__()
        self.ap = ap
        self.item_list = items
        self.seq_len = seq_len if return_segments else None
        self.hop_len = hop_len
        self.pad_short = pad_short
        self.conv_pad = conv_pad
        self.is_training = is_training
        self.return_segments = return_segments
        self.use_cache = use_cache
        self.use_noise_augment = use_noise_augment
        self.verbose = verbose

        if return_segments:
            assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
        self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)

        # cache acoustic features
        if use_cache:
            self.create_feature_cache()

    def create_feature_cache(self):
        self.manager = Manager()
        self.cache = self.manager.list()
        self.cache += [None for _ in range(len(self.item_list))]

    @staticmethod
    def find_wav_files(path):
        return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, idx):
        item = self.load_item(idx)
        return item

    def load_test_samples(self, num_samples: int) -> List[Tuple]:
        """Return test samples.

        Args:
            num_samples (int): Number of samples to return.

        Returns:
            List[Tuple]: melspectrogram and audio.

        Shapes:
            - melspectrogram (Tensor): :math:`[C, T]`
            - audio (Tensor): :math:`[T_audio]`
        """
        samples = []
        return_segments = self.return_segments
        self.return_segments = False
        for idx in range(num_samples):
            mel, audio = self.load_item(idx)
            samples.append([mel, audio])
        self.return_segments = return_segments
        return samples

    def load_item(self, idx):
        """load (audio, feat) couple"""
        # compute features from wav
        wavpath = self.item_list[idx]

        if self.use_cache and self.cache[idx] is not None:
            audio = self.cache[idx]
        else:
            audio = self.ap.load_wav(wavpath)

            if self.return_segments:
                # correct audio length wrt segment length
                if audio.shape[-1] < self.seq_len + self.pad_short:
                    audio = np.pad(
                        audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0
                    )
                assert (
                    audio.shape[-1] >= self.seq_len + self.pad_short
                ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"

            # correct the audio length wrt hop length
            p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
            audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0)

            if self.use_cache:
                self.cache[idx] = audio

        if self.return_segments:
            max_start = len(audio) - self.seq_len
            start = random.randint(0, max_start)
            end = start + self.seq_len
            audio = audio[start:end]

        if self.use_noise_augment and self.is_training and self.return_segments:
            audio = audio + (1 / 32768) * torch.randn_like(audio)

        mel = self.ap.melspectrogram(audio)
        mel = mel[..., :-1]  # ignore the padding

        audio = torch.from_numpy(audio).float()
        mel = torch.from_numpy(mel).float().squeeze(0)
        return (mel, audio)

    @staticmethod
    def collate_full_clips(batch):
        """This is used in tune_wavegrad.py.
        It pads sequences to the max length."""
        max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]
        max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0]

        mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
        audios = torch.zeros([len(batch), max_audio_length])

        for idx, b in enumerate(batch):
            mel = b[0]
            audio = b[1]
            mels[idx, :, : mel.shape[1]] = mel
            audios[idx, : audio.shape[0]] = audio

        return mels, audios
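`collate_full_clips` zero-pads variable-length clips into one batch. A minimal check with dummy tensors (shapes are illustrative):

# Minimal check of collate_full_clips padding behavior with dummy tensors.
import torch
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset

batch = [
    (torch.ones(80, 40), torch.ones(10240)),  # (mel [C, T], audio [T_audio])
    (torch.ones(80, 32), torch.ones(8192)),
]
mels, audios = WaveGradDataset.collate_full_clips(batch)
print(mels.shape, audios.shape)  # torch.Size([2, 80, 40]) torch.Size([2, 10240])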
@@ -0,0 +1,118 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset

from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


class WaveRNNDataset(Dataset):
    """
    WaveRNN Dataset searches for all the wav files under root path
    and converts them to acoustic features on the fly.
    """

    def __init__(
        self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
    ):
        super().__init__()
        self.ap = ap
        self.compute_feat = not isinstance(items[0], (tuple, list))
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.mel_len = seq_len // hop_len
        self.pad = pad
        self.mode = mode
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose
        self.return_segments = return_segments

        assert self.seq_len % self.hop_len == 0

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, index):
        item = self.load_item(index)
        return item

    def load_test_samples(self, num_samples):
        samples = []
        return_segments = self.return_segments
        self.return_segments = False
        for idx in range(num_samples):
            mel, audio, _ = self.load_item(idx)
            samples.append([mel, audio])
        self.return_segments = return_segments
        return samples

    def load_item(self, index):
        """
        Load an (audio, feat) couple if feature_path is set,
        else compute the features on the fly.
        """
        if self.compute_feat:
            wavpath = self.item_list[index]
            audio = self.ap.load_wav(wavpath)
            if self.return_segments:
                min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
            else:
                min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
            if audio.shape[0] < min_audio_len:
                print(" [!] Instance is too short! : {}".format(wavpath))
                audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
            mel = self.ap.melspectrogram(audio)

            if self.mode in ["gauss", "mold"]:
                x_input = audio
            elif isinstance(self.mode, int):
                x_input = (
                    mulaw_encode(wav=audio, mulaw_qc=self.mode)
                    if self.mulaw
                    else quantize(x=audio, quantize_bits=self.mode)
                )
            else:
                raise RuntimeError(f"Unknown dataset mode - {self.mode}")

        else:
            wavpath, feat_path = self.item_list[index]
            mel = np.load(feat_path.replace("/quant/", "/mel/"))

            if mel.shape[-1] < self.mel_len + 2 * self.pad:
                print(" [!] Instance is too short! : {}".format(wavpath))
                self.item_list[index] = self.item_list[index + 1]
                feat_path = self.item_list[index]
                mel = np.load(feat_path.replace("/quant/", "/mel/"))
            if self.mode in ["gauss", "mold"]:
                x_input = self.ap.load_wav(wavpath)
            elif isinstance(self.mode, int):
                x_input = np.load(feat_path.replace("/mel/", "/quant/"))
            else:
                raise RuntimeError(f"Unknown dataset mode - {self.mode}")

        return mel, x_input, wavpath

    def collate(self, batch):
        mel_win = self.seq_len // self.hop_len + 2 * self.pad
        max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]

        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
        sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]

        mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)]

        coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)]

        mels = np.stack(mels).astype(np.float32)
        if self.mode in ["gauss", "mold"]:
            coarse = np.stack(coarse).astype(np.float32)
            coarse = torch.FloatTensor(coarse)
            x_input = coarse[:, : self.seq_len]
        elif isinstance(self.mode, int):
            coarse = np.stack(coarse).astype(np.int64)
            coarse = torch.LongTensor(coarse)
            x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0
        y_coarse = coarse[:, 1:]
        mels = torch.FloatTensor(mels)
        return x_input, mels, y_coarse
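For intuition on the integer modes above, here is a minimal sketch of standard mu-law companding, re-derived for illustration; the exact TTS.utils.audio.numpy_transforms implementation may differ in details.

# Sketch of mu-law companding as used for WaveRNN integer targets (illustrative only).
import numpy as np

def mulaw_encode_sketch(wav: np.ndarray, qc: int) -> np.ndarray:
    """Map wav in [-1, 1] to integer classes in [0, 2**qc - 1]."""
    mu = 2**qc - 1
    signal = np.sign(wav) * np.log1p(mu * np.abs(wav)) / np.log1p(mu)
    return np.floor((signal + 1) / 2 * mu + 0.5).astype(np.int64)

def mulaw_decode_sketch(enc: np.ndarray, qc: int) -> np.ndarray:
    mu = 2**qc - 1
    signal = 2 * enc.astype(np.float64) / mu - 1
    return np.sign(signal) * np.expm1(np.abs(signal) * np.log1p(mu)) / mu

x = np.linspace(-1, 1, 11)
# round trip is lossy only up to the quantization step
assert np.allclose(mulaw_decode_sketch(mulaw_encode_sketch(x, 10), 10), x, atol=1e-2)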
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,56 @@
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations


# pylint: disable=dangerous-default-value
class ResStack(nn.Module):
    def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]):
        super().__init__()
        resstack = []
        for dilation in dilations:
            resstack += [
                nn.LeakyReLU(0.2),
                nn.ReflectionPad1d(dilation),
                nn.utils.parametrizations.weight_norm(
                    nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
                ),
                nn.LeakyReLU(0.2),
                nn.ReflectionPad1d(padding),
                nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
            ]
        self.resstack = nn.Sequential(*resstack)

        self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))

    def forward(self, x):
        x1 = self.shortcut(x)
        x2 = self.resstack(x)
        return x1 + x2

    def remove_weight_norm(self):
        remove_parametrizations(self.shortcut, "weight")
        remove_parametrizations(self.resstack[2], "weight")
        remove_parametrizations(self.resstack[5], "weight")
        remove_parametrizations(self.resstack[8], "weight")
        remove_parametrizations(self.resstack[11], "weight")
        remove_parametrizations(self.resstack[14], "weight")
        remove_parametrizations(self.resstack[17], "weight")


class MRF(nn.Module):
    def __init__(self, kernels, channel, dilations=[1, 3, 5]):  # pylint: disable=dangerous-default-value
        super().__init__()
        self.resblock1 = ResStack(kernels[0], channel, 0, dilations)
        self.resblock2 = ResStack(kernels[1], channel, 6, dilations)
        self.resblock3 = ResStack(kernels[2], channel, 12, dilations)

    def forward(self, x):
        x1 = self.resblock1(x)
        x2 = self.resblock2(x)
        x3 = self.resblock3(x)
        return x1 + x2 + x3

    def remove_weight_norm(self):
        self.resblock1.remove_weight_norm()
        self.resblock2.remove_weight_norm()
        self.resblock3.remove_weight_norm()
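A quick shape sanity check for the residual stack above (a sketch; assumes the ResStack class defined above is in scope). With kernel=3 and padding=0, the reflection pad of `dilation` exactly offsets the dilated convolution's size reduction of dilation * (kernel - 1), so the shortcut can be summed with the stack output:

import torch

stack = ResStack(kernel=3, channel=4, padding=0, dilations=[1, 3, 5])
x = torch.randn(2, 4, 100)  # (batch, channels, time)
assert stack(x).shape == x.shape  # time axis is preserved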
@@ -0,0 +1,368 @@
from typing import Dict, Union

import torch
from torch import nn
from torch.nn import functional as F

from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss

#################################
# GENERATOR LOSSES
#################################


class STFTLoss(nn.Module):
    """STFT loss. Generated and real waveforms are converted to spectrograms
    and compared with L1 and spectral convergence losses.
    From the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        y_hat_M = self.stft(y_hat)
        y_M = self.stft(y)
        # magnitude loss
        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
        # spectral convergence loss
        loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
        return loss_mag, loss_sc
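For intuition, a self-contained sketch of the two loss terms computed directly with torch.stft; this illustrates the math, not the TorchSTFT code path used above.

import torch

def stft_loss_sketch(y_hat, y, n_fft=1024, hop=256):
    window = torch.hann_window(n_fft)
    # magnitude spectrograms; a small floor keeps the log finite
    def mag(w):
        return torch.stft(w, n_fft, hop, window=window, return_complex=True).abs().clamp(min=1e-7)
    y_hat_m, y_m = mag(y_hat), mag(y)
    loss_mag = torch.nn.functional.l1_loss(torch.log(y_m), torch.log(y_hat_m))
    loss_sc = torch.norm(y_m - y_hat_m, p="fro") / torch.norm(y_m, p="fro")
    return loss_mag, loss_sc

y = torch.randn(2, 16000)
y_hat = y + 0.05 * torch.randn_like(y)
print(stft_loss_sketch(y_hat, y))  # both terms shrink as y_hat approaches y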
class MultiScaleSTFTLoss(torch.nn.Module):
    """Multi-scale STFT loss. Generated and real waveforms are converted to spectrograms
    and compared with L1 and spectral convergence losses.
    From the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""

    def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
        super().__init__()
        self.loss_funcs = torch.nn.ModuleList()
        for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
            self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))

    def forward(self, y_hat, y):
        N = len(self.loss_funcs)
        loss_sc = 0
        loss_mag = 0
        for f in self.loss_funcs:
            lm, lsc = f(y_hat, y)
            loss_mag += lm
            loss_sc += lsc
        loss_sc /= N
        loss_mag /= N
        return loss_mag, loss_sc


class L1SpecLoss(nn.Module):
    """L1 loss over spectrograms as described in the HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""

    def __init__(
        self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
    ):
        super().__init__()
        self.use_mel = use_mel
        self.stft = TorchSTFT(
            n_fft,
            hop_length,
            win_length,
            sample_rate=sample_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            n_mels=n_mels,
            use_mel=use_mel,
        )

    def forward(self, y_hat, y):
        y_hat_M = self.stft(y_hat)
        y_M = self.stft(y)
        # magnitude loss
        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
        return loss_mag


class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
    """Multi-scale STFT loss for multi-band model outputs.
    From the Multi-Band MelGAN paper https://arxiv.org/abs/2005.05106"""

    # pylint: disable=no-self-use
    def forward(self, y_hat, y):
        y_hat = y_hat.view(-1, 1, y_hat.shape[2])
        y = y.view(-1, 1, y.shape[2])
        return super().forward(y_hat.squeeze(1), y.squeeze(1))


class MSEGLoss(nn.Module):
    """Mean Squared Generator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake):
        loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape))
        return loss_fake


class HingeGLoss(nn.Module):
    """Hinge Generator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake):
        # TODO: this might be wrong
        loss_fake = torch.mean(F.relu(1.0 - score_fake))
        return loss_fake


##################################
# DISCRIMINATOR LOSSES
##################################


class MSEDLoss(nn.Module):
    """Mean Squared Discriminator Loss"""

    def __init__(
        self,
    ):
        super().__init__()
        self.loss_func = nn.MSELoss()

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
        loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
        loss_d = loss_real + loss_fake
        return loss_d, loss_real, loss_fake


class HingeDLoss(nn.Module):
    """Hinge Discriminator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        loss_real = torch.mean(F.relu(1.0 - score_real))
        loss_fake = torch.mean(F.relu(1.0 + score_fake))
        loss_d = loss_real + loss_fake
        return loss_d, loss_real, loss_fake


class MelganFeatureLoss(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()
        self.loss_func = nn.L1Loss()

    # pylint: disable=no-self-use
    def forward(self, fake_feats, real_feats):
        loss_feats = 0
        num_feats = 0
        for idx, _ in enumerate(fake_feats):
            for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
                loss_feats += self.loss_func(fake_feat, real_feat)
                num_feats += 1
        loss_feats = loss_feats / num_feats
        return loss_feats


#####################################
# LOSS WRAPPERS
#####################################


def _apply_G_adv_loss(scores_fake, loss_func):
    """Compute the generator adversarial loss and normalize it across scales."""
    adv_loss = 0
    if isinstance(scores_fake, list):
        for score_fake in scores_fake:
            fake_loss = loss_func(score_fake)
            adv_loss += fake_loss
        adv_loss /= len(scores_fake)
    else:
        fake_loss = loss_func(scores_fake)
        adv_loss = fake_loss
    return adv_loss


def _apply_D_loss(scores_fake, scores_real, loss_func):
    """Compute the discriminator loss and normalize it across scales."""
    loss = 0
    real_loss = 0
    fake_loss = 0
    if isinstance(scores_fake, list):
        # multi-scale loss
        for score_fake, score_real in zip(scores_fake, scores_real):
            total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real)
            loss += total_loss
            real_loss += real_loss_
            fake_loss += fake_loss_
        # normalize loss values with the number of scales (discriminators)
        loss /= len(scores_fake)
        real_loss /= len(scores_real)
        fake_loss /= len(scores_fake)
    else:
        # single-scale loss
        total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
        loss = total_loss
    return loss, real_loss, fake_loss
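A short sketch of the multi-scale wrapper (assumes _apply_D_loss and MSEDLoss above): score tensors from three discriminator scales are reduced to one averaged loss triple.

import torch

loss_func = MSEDLoss()
scores_fake = [torch.randn(2, 1, t) for t in (64, 32, 16)]  # one entry per scale
scores_real = [torch.randn(2, 1, t) for t in (64, 32, 16)]
loss, real_loss, fake_loss = _apply_D_loss(scores_fake, scores_real, loss_func)
print(loss.item(), real_loss.item(), fake_loss.item())  # averaged over the 3 scales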
##################################
# MODEL LOSSES
##################################


class GeneratorLoss(nn.Module):
    """Generator loss wrapper. Based on the model configuration, it sets the right set of loss
    functions and computes the losses. It allows experimenting with different combinations of
    loss functions and models by only changing configurations.

    Args:
        C (AttrDict): model configuration.
    """

    def __init__(self, C):
        super().__init__()
        assert not (
            C.use_mse_gan_loss and C.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
        self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
        self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False
        self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False
        self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False
        self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False

        self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0
        self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0
        self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0
        self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinge_G_loss_weight" in C else 0.0
        self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0
        self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0

        if C.use_stft_loss:
            self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
        if C.use_subband_stft_loss:
            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
        if C.use_mse_gan_loss:
            self.mse_loss = MSEGLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeGLoss()
        if C.use_feat_match_loss:
            self.feat_match_loss = MelganFeatureLoss()
        if C.use_l1_spec_loss:
            assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
            self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)

    def forward(
        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
    ):
        gen_loss = 0
        adv_loss = 0
        return_dict = {}

        # STFT Loss
        if self.use_stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
            return_dict["G_stft_loss_mg"] = stft_loss_mg
            return_dict["G_stft_loss_sc"] = stft_loss_sc
            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

        # L1 Spec loss
        if self.use_l1_spec_loss:
            l1_spec_loss = self.l1_spec_loss(y_hat, y)
            return_dict["G_l1_spec_loss"] = l1_spec_loss
            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

        # subband STFT Loss
        if self.use_subband_stft_loss:
            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

        # multiscale MSE adversarial loss
        if self.use_mse_gan_loss and scores_fake is not None:
            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
            return_dict["G_mse_fake_loss"] = mse_fake_loss
            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

        # multiscale Hinge adversarial loss
        if self.use_hinge_gan_loss and scores_fake is not None:
            hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
            return_dict["G_hinge_fake_loss"] = hinge_fake_loss
            adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

        # Feature Matching Loss
        if self.use_feat_match_loss and feats_fake is not None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
        return_dict["loss"] = gen_loss + adv_loss
        return_dict["G_gen_loss"] = gen_loss
        return_dict["G_adv_loss"] = adv_loss
        return return_dict


class DiscriminatorLoss(nn.Module):
    """Like ``GeneratorLoss``"""

    def __init__(self, C):
        super().__init__()
        assert not (
            C.use_mse_gan_loss and C.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_mse_gan_loss = C.use_mse_gan_loss
        self.use_hinge_gan_loss = C.use_hinge_gan_loss

        if C.use_mse_gan_loss:
            self.mse_loss = MSEDLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeDLoss()

    def forward(self, scores_fake, scores_real):
        loss = 0
        return_dict = {}

        if self.use_mse_gan_loss:
            mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
            )
            return_dict["D_mse_gan_loss"] = mse_D_loss
            return_dict["D_mse_gan_real_loss"] = mse_D_real_loss
            return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss
            loss += mse_D_loss

        if self.use_hinge_gan_loss:
            hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
            )
            return_dict["D_hinge_gan_loss"] = hinge_D_loss
            return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss
            return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
            loss += hinge_D_loss

        return_dict["loss"] = loss
        return return_dict


class WaveRNNLoss(nn.Module):
    def __init__(self, wave_rnn_mode: Union[str, int]):
        super().__init__()
        if wave_rnn_mode == "mold":
            self.loss_func = discretized_mix_logistic_loss
        elif wave_rnn_mode == "gauss":
            self.loss_func = gaussian_loss
        elif isinstance(wave_rnn_mode, int):
            self.loss_func = torch.nn.CrossEntropyLoss()
        else:
            raise ValueError(" [!] Unknown mode for Wavernn.")

    def forward(self, y_hat, y) -> Dict:
        loss = self.loss_func(y_hat, y)
        return {"loss": loss}
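A small usage sketch of the mode switch above (assumes WaveRNNLoss as defined in this file): an integer mode selects cross-entropy over quantized sample classes.

import torch

criterion = WaveRNNLoss(10)  # 10-bit mode -> 1024 classes, CrossEntropyLoss
y_hat = torch.randn(8, 1024, 100)      # (batch, classes, time) logits
y = torch.randint(0, 1024, (8, 100))   # quantized target samples
print(criterion(y_hat, y)["loss"])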
@@ -0,0 +1,198 @@
import torch
import torch.nn.functional as F


class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions"""

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
    ):
        """
        Args:
            cond_channels (int): number of channels of the conditioning sequence,
            conv_in_channels (int): number of channels of the input sequence,
            conv_out_channels (int): number of channels of the output sequence,
            conv_layers (int): number of location-variable convolution layers to predict kernels for,
            kpnet_* : hyperparameters of the kernel predictor network itself.
        """
        super().__init__()

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        l_b = conv_out_channels * conv_layers

        padding = (kpnet_conv_size - 1) // 2
        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.residual_conv = torch.nn.Sequential(
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True)
        self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True)

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        Returns:
            kernels (Tensor): (batch, conv_layers, conv_in_channels, conv_out_channels, kernel_size, cond_length)
            bias (Tensor): (batch, conv_layers, conv_out_channels, cond_length)
        """
        batch, _, cond_length = c.shape

        c = self.input_conv(c)
        c = c + self.residual_conv(c)
        k = self.kernel_conv(c)
        b = self.bias_conv(c)

        kernels = k.contiguous().view(
            batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length
        )
        bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length)
        return kernels, bias
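A quick shape check for the predictor above (a sketch; assumes KernelPredictor as defined above): one kernel and bias per conditioning frame and per LVC layer.

import torch

kp = KernelPredictor(cond_channels=80, conv_in_channels=8, conv_out_channels=16, conv_layers=4)
c = torch.randn(1, 80, 32)  # (batch, cond_channels, cond_length)
kernels, bias = kp(c)
assert kernels.shape == (1, 4, 8, 16, 3, 32)
assert bias.shape == (1, 4, 16, 32)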

class LVCBlock(torch.nn.Module):
    """The location-variable convolutions"""

    def __init__(
        self,
        in_channels,
        cond_channels,
        upsample_ratio,
        conv_layers=4,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = conv_layers
        self.conv_kernel_size = conv_kernel_size
        self.convs = torch.nn.ModuleList()

        self.upsample = torch.nn.ConvTranspose1d(
            in_channels,
            in_channels,
            kernel_size=upsample_ratio * 2,
            stride=upsample_ratio,
            padding=upsample_ratio // 2 + upsample_ratio % 2,
            output_padding=upsample_ratio % 2,
        )

        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=conv_layers,
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
        )

        for i in range(conv_layers):
            padding = (3**i) * int((conv_kernel_size - 1) / 2)
            conv = torch.nn.Conv1d(
                in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i
            )

            self.convs.append(conv)

    def forward(self, x, c):
        """Forward propagation of the location-variable convolutions.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        """
        in_channels = x.shape[1]
        kernels, bias = self.kernel_predictor(c)

        x = F.leaky_relu(x, 0.2)
        x = self.upsample(x)

        for i in range(self.conv_layers):
            y = F.leaky_relu(x, 0.2)
            y = self.convs[i](y)
            y = F.leaky_relu(y, 0.2)

            k = kernels[:, i, :, :, :, :]
            b = bias[:, i, :, :]
            y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
            x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
        return x

    @staticmethod
    def location_variable_convolution(x, kernel, bias, dilation, hop_size):
        """Perform the location-variable convolution operation on the input sequence (x)
        using the local convolution kernel.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of the convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing the local convolution (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape

        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(
            3, dilation, dilation
        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o + bias.unsqueeze(-1).unsqueeze(-1)
        o = o.contiguous().view(batch, out_channels, -1)
        return o
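A minimal end-to-end shape check (assumes the classes above; sizes are hypothetical): the upsampled input length must equal cond_length * cond_hop_length, so with upsample_ratio == cond_hop_length the input and conditioning can share the same frame count.

import torch

block = LVCBlock(in_channels=8, cond_channels=80, upsample_ratio=4, cond_hop_length=4)
x = torch.randn(1, 8, 32)   # (batch, in_channels, frames)
c = torch.randn(1, 80, 32)  # (batch, cond_channels, frames)
out = block(x, c)           # x is upsampled x4, then gated by location-variable convs
assert out.shape == (1, 8, 128)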
@@ -0,0 +1,43 @@
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations


class ResidualStack(nn.Module):
    def __init__(self, channels, num_res_blocks, kernel_size):
        super().__init__()

        assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
        base_padding = (kernel_size - 1) // 2

        self.blocks = nn.ModuleList()
        for idx in range(num_res_blocks):
            layer_kernel_size = kernel_size
            layer_dilation = layer_kernel_size**idx
            layer_padding = base_padding * layer_dilation
            self.blocks += [
                nn.Sequential(
                    nn.LeakyReLU(0.2),
                    nn.ReflectionPad1d(layer_padding),
                    weight_norm(
                        nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
                    ),
                    nn.LeakyReLU(0.2),
                    weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
                )
            ]

        self.shortcuts = nn.ModuleList(
            [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
        )

    def forward(self, x):
        for block, shortcut in zip(self.blocks, self.shortcuts):
            x = shortcut(x) + block(x)
        return x

    def remove_weight_norm(self):
        for block, shortcut in zip(self.blocks, self.shortcuts):
            remove_parametrizations(block[2], "weight")
            remove_parametrizations(block[4], "weight")
            remove_parametrizations(shortcut, "weight")
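Since the dilation grows as kernel_size**idx, the stack's receptive field grows geometrically with depth. A small sketch of that arithmetic (my own derivation, with a hypothetical helper name, for illustration only):

def residual_stack_receptive_field(kernel_size: int, num_res_blocks: int) -> int:
    # each sequential block adds dilation * (kernel_size - 1) samples of context
    rf = 1
    for idx in range(num_res_blocks):
        rf += (kernel_size**idx) * (kernel_size - 1)
    return rf

print(residual_stack_receptive_field(3, 3))  # 1 + 2 + 6 + 18 = 27 samples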
@@ -0,0 +1,77 @@
import torch
from torch.nn import functional as F


class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(
        self,
        kernel_size=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        dilation=1,
        bias=True,
        use_causal_conv=False,
    ):
        super().__init__()
        self.dropout = dropout
        # no future time stamps available
        if use_causal_conv:
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
            padding = (kernel_size - 1) // 2 * dilation
        self.use_causal_conv = use_causal_conv

        # dilated conv
        self.conv = torch.nn.Conv1d(
            res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias
        )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
        else:
            self.conv1x1_aux = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
        self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)

    def forward(self, x, c):
        """
        x: B x D_res x T
        c: B x D_aux x T
        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv(x)

        # remove future time steps if using a causal conv
        x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x

        # split into two parts for the gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # for skip connection
        s = self.conv1x1_skip(x)

        # for residual connection
        x = (self.conv1x1_out(x) + residual) * (0.5**2)

        return x, s
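A usage sketch (assumes ResidualBlock above): the dilated conv doubles channels to gate_channels, the gated tanh/sigmoid halves them again, and the block returns a residual path and a skip path.

import torch

block = ResidualBlock()  # defaults: res 64, gate 128, skip 64, aux 80
x = torch.randn(2, 64, 100)  # residual-channel input
c = torch.randn(2, 80, 100)  # local conditioning (e.g., upsampled mels)
res, skip = block(x, c)
assert res.shape == (2, 64, 100) and skip.shape == (2, 64, 100)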
@@ -0,0 +1,53 @@
import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal as sig


# adapted from
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
class PQMF(torch.nn.Module):
    def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
        super().__init__()

        self.N = N
        self.taps = taps
        self.cutoff = cutoff
        self.beta = beta

        QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta))
        H = np.zeros((N, len(QMF)))
        G = np.zeros((N, len(QMF)))
        for k in range(N):
            constant_factor = (
                (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2))
            )  # TODO: (taps - 1) -> taps
            phase = (-1) ** k * np.pi / 4
            H[k] = 2 * QMF * np.cos(constant_factor + phase)

            G[k] = 2 * QMF * np.cos(constant_factor - phase)

        H = torch.from_numpy(H[:, None, :]).float()
        G = torch.from_numpy(G[None, :, :]).float()

        self.register_buffer("H", H)
        self.register_buffer("G", G)

        updown_filter = torch.zeros((N, N, N)).float()
        for k in range(N):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.N = N

        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def forward(self, x):
        return self.analysis(x)

    def analysis(self, x):
        return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)

    def synthesis(self, x):
        x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N)
        x = F.conv1d(x, self.G, padding=self.taps // 2)
        return x
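A round-trip sketch (assumes the PQMF class above): analysis splits the waveform into N subbands at 1/N of the sample rate, and synthesis reconstructs an approximation of the input.

import torch

pqmf = PQMF(N=4)
x = torch.randn(1, 1, 4096)        # (batch, 1, samples)
subbands = pqmf.analysis(x)        # (1, 4, 1024): 4 bands at 1/4 sample rate
x_hat = pqmf.synthesis(subbands)   # (1, 1, 4096)
print((x - x_hat).abs().mean())    # small for a well-designed prototype filter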
@@ -0,0 +1,640 @@
[640-line data file: coefficients of a symmetric FIR low-pass filter; numeric listing omitted]
@@ -0,0 +1,102 @@
import torch
from torch.nn import functional as F


class Stretch2d(torch.nn.Module):
    def __init__(self, x_scale, y_scale, mode="nearest"):
        super().__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor (B, C, F, T).
        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
        """
        return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)


class UpsampleNetwork(torch.nn.Module):
    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
    ):
        super().__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_factors:
            # interpolation layer
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # conv layer
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Even frequency-axis kernel sizes are not supported."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # nonlinearity
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers += [nonlinear]

    def forward(self, c):
        """
        Args:
            c (Tensor): (B, C, T_in).
        Returns:
            Tensor: (B, C, T_upsampled).
        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for f in self.up_layers:
            c = f(c)
        return c.squeeze(1)  # (B, C, T')


class ConvUpsample(torch.nn.Module):
    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
    ):
        super().__init__()
        self.aux_context_window = aux_context_window
        self.use_causal_conv = use_causal_conv and aux_context_window > 0
        # To capture wide-context information in conditional features
        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
        self.upsample = UpsampleNetwork(
            upsample_factors=upsample_factors,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )

    def forward(self, c):
        """
        Args:
            c (Tensor): (B, C, T_in).
        Returns:
            Tensor: (B, C, T_upsampled).
        """
        c_ = self.conv_in(c)
        c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
        return self.upsample(c)
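A shape sketch (assumes ConvUpsample above; sizes are hypothetical): with aux_context_window=2 the unpadded conv_in consumes two frames of context on each side before the 4x4 = 16x interpolation.

import torch

up = ConvUpsample(upsample_factors=[4, 4], aux_context_window=2)
c = torch.randn(1, 80, 10)   # 10 mel frames, assumed pre-padded with 2 context frames per side
out = up(c)                  # conv_in trims to 6 frames, then upsamples x16
assert out.shape == (1, 80, 96)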
@@ -0,0 +1,166 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations


class Conv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.orthogonal_(self.weight)
        nn.init.zeros_(self.bias)


class PositionalEncoding(nn.Module):
    """Positional encoding with noise level conditioning"""

    def __init__(self, n_channels, max_len=10000):
        super().__init__()
        self.n_channels = n_channels
        self.max_len = max_len
        self.C = 5000
        self.pe = torch.zeros(0, 0)

    def forward(self, x, noise_level):
        if x.shape[2] > self.pe.shape[1]:
            self.init_pe_matrix(x.shape[1], x.shape[2], x)
        return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C

    def init_pe_matrix(self, n_channels, max_len, x):
        pe = torch.zeros(max_len, n_channels)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)

        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)
        self.pe = pe.transpose(0, 1).to(x)


class FiLM(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.encoding = PositionalEncoding(input_size)
        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)

        nn.init.xavier_uniform_(self.input_conv.weight)
        nn.init.xavier_uniform_(self.output_conv.weight)
        nn.init.zeros_(self.input_conv.bias)
        nn.init.zeros_(self.output_conv.bias)

    def forward(self, x, noise_scale):
        o = self.input_conv(x)
        o = F.leaky_relu(o, 0.2)
        o = self.encoding(o, noise_scale)
        shift, scale = torch.chunk(self.output_conv(o), 2, dim=1)
        return shift, scale

    def remove_weight_norm(self):
        remove_parametrizations(self.input_conv, "weight")
        remove_parametrizations(self.output_conv, "weight")

    def apply_weight_norm(self):
        self.input_conv = weight_norm(self.input_conv)
        self.output_conv = weight_norm(self.output_conv)


@torch.jit.script
def shift_and_scale(x, scale, shift):
    o = shift + scale * x
    return o


class UBlock(nn.Module):
    def __init__(self, input_size, hidden_size, factor, dilation):
        super().__init__()
        assert isinstance(dilation, (list, tuple))
        assert len(dilation) == 4

        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)
        self.main_block = nn.ModuleList(
            [
                Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]),
            ]
        )
        self.out_block = nn.ModuleList(
            [
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]),
            ]
        )

    def forward(self, x, shift, scale):
        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
        res = self.res_block(x_inter)
        o = F.leaky_relu(x_inter, 0.2)
        o = F.interpolate(o, size=x.shape[-1] * self.factor)
        o = self.main_block[0](o)
        o = shift_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block[1](o)
        res2 = res + o
        o = shift_and_scale(res2, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[0](o)
        o = shift_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[1](o)
        o = o + res2
        return o

    def remove_weight_norm(self):
        remove_parametrizations(self.res_block, "weight")
        for _, layer in enumerate(self.main_block):
            if len(layer.state_dict()) != 0:
                remove_parametrizations(layer, "weight")
        for _, layer in enumerate(self.out_block):
            if len(layer.state_dict()) != 0:
                remove_parametrizations(layer, "weight")

    def apply_weight_norm(self):
        self.res_block = weight_norm(self.res_block)
        for idx, layer in enumerate(self.main_block):
            if len(layer.state_dict()) != 0:
                self.main_block[idx] = weight_norm(layer)
        for idx, layer in enumerate(self.out_block):
            if len(layer.state_dict()) != 0:
                self.out_block[idx] = weight_norm(layer)


class DBlock(nn.Module):
    def __init__(self, input_size, hidden_size, factor):
        super().__init__()
        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)
        self.main_block = nn.ModuleList(
            [
                Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
                Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
                Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
            ]
        )

    def forward(self, x):
        size = x.shape[-1] // self.factor
        res = self.res_block(x)
        res = F.interpolate(res, size=size)
        o = F.interpolate(x, size=size)
        for layer in self.main_block:
            o = F.leaky_relu(o, 0.2)
            o = layer(o)
        return o + res

    def remove_weight_norm(self):
        remove_parametrizations(self.res_block, "weight")
        for _, layer in enumerate(self.main_block):
            if len(layer.state_dict()) != 0:
                remove_parametrizations(layer, "weight")

    def apply_weight_norm(self):
        self.res_block = weight_norm(self.res_block)
        for idx, layer in enumerate(self.main_block):
            if len(layer.state_dict()) != 0:
                self.main_block[idx] = weight_norm(layer)
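A sketch of how FiLM conditioning drives a UBlock (assumes the classes above; all sizes are hypothetical): the FiLM output length must match the UBlock's upsampled length, and its output_size must equal the UBlock's hidden_size.

import torch

film = FiLM(input_size=80, output_size=64)
ublock = UBlock(input_size=32, hidden_size=64, factor=2, dilation=[1, 2, 4, 8])

x = torch.randn(1, 32, 16)           # features entering the UBlock
c = torch.randn(1, 80, 32)           # conditioning at the upsampled resolution
noise_level = torch.rand(1)          # diffusion noise-level input
shift, scale = film(c, noise_level)  # each (1, 64, 32)
out = ublock(x, shift, scale)
assert out.shape == (1, 64, 32)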
@@ -0,0 +1,154 @@
import importlib
import re

from coqpit import Coqpit


def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
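For reference, a couple of example conversions (derived directly from the regex above):

assert to_camel("wavegrad") == "Wavegrad"
assert to_camel("multiband_melgan_generator") == "MultibandMelganGenerator"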

def setup_model(config: Coqpit):
    """Load models directly from configuration."""
    if "discriminator_model" in config and "generator_model" in config:
        MyModel = importlib.import_module("TTS.vocoder.models.gan")
        MyModel = getattr(MyModel, "GAN")
    else:
        MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
        if config.model.lower() == "wavernn":
            MyModel = getattr(MyModel, "Wavernn")
        elif config.model.lower() == "gan":
            MyModel = getattr(MyModel, "GAN")
        elif config.model.lower() == "wavegrad":
            MyModel = getattr(MyModel, "Wavegrad")
        else:
            try:
                MyModel = getattr(MyModel, to_camel(config.model))
            except ModuleNotFoundError as e:
                raise ValueError(f"Model {config.model} does not exist!") from e
    print(" > Vocoder Model: {}".format(config.model))
    return MyModel.init_from_config(config)


def setup_generator(c):
    """TODO: use config object as arguments"""
    print(" > Generator Model: {}".format(c.generator_model))
    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.generator_model))
    # `to_camel` maps e.g. "wavernn" to "Wavernn", matching the class name used in the codebase
    if c.generator_model.lower() in "hifigan_generator":
        model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
    elif c.generator_model.lower() in "melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model in "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
    elif c.generator_model.lower() in "multiband_melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=4,
            proj_kernel=7,
            base_channels=384,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model.lower() in "fullband_melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model.lower() in "parallel_wavegan_generator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )
    elif c.generator_model.lower() in "univnet_generator":
        model = MyModel(**c.generator_model_params)
    else:
        raise NotImplementedError(f"Model {c.generator_model} not implemented!")
    return model


def setup_discriminator(c):
    """TODO: use config object as arguments"""
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    if "parallel_wavegan" in c.discriminator_model:
        MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
    else:
        MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
    if c.discriminator_model in "hifigan_discriminator":
        model = MyModel()
    if c.discriminator_model in "random_window_discriminator":
        model = MyModel(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
            window_sizes=c.discriminator_model_params["window_sizes"],
        )
    if c.discriminator_model in "melgan_multiscale_discriminator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=c.discriminator_model_params["base_channels"],
|
||||
max_channels=c.discriminator_model_params["max_channels"],
|
||||
downsample_factors=c.discriminator_model_params["downsample_factors"],
|
||||
)
|
||||
if c.discriminator_model == "residual_parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
stacks=c.discriminator_model_params["stacks"],
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
)
|
||||
if c.discriminator_model == "parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
conv_channels=64,
|
||||
dilation_factor=1,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
bias=True,
|
||||
)
|
||||
if c.discriminator_model == "univnet_discriminator":
|
||||
model = MyModel()
|
||||
return model
|
||||
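
For reference, a minimal sketch (not part of the committed file) of how `to_camel` maps config model names onto class names, and why the `in` checks above behave as substring tests rather than equality tests:

import re

def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)

assert to_camel("multiband_melgan_generator") == "MultibandMelganGenerator"
assert to_camel("wavernn") == "Wavernn"   # not "WaveRNN" -- hence the special case above
# `x in y` on strings is a substring test, so shorthand names also match:
assert "hifigan" in "hifigan_generator"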
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,55 @@
from coqpit import Coqpit

from TTS.model import BaseTrainerModel

# pylint: skip-file


class BaseVocoder(BaseTrainerModel):
    """Base `vocoder` class. Every new `vocoder` model must inherit this.

    It defines `vocoder` specific functions on top of `Model`.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    MODEL_TYPE = "vocoder"

    def __init__(self, config):
        super().__init__()
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Setup model args based on the config type.

        If the config is for training, with a name like "*Config", then the model args are embedded in
        `config.model_args`.

        If the config is for the model, with a name like "*Args", then we assign it directly.
        """
        # don't use isinstance() to avoid importing recursively
        if "Config" in config.__class__.__name__:
            if "characters" in config:
                _, self.config, num_chars = self.get_characters(config)
                self.config.num_chars = num_chars
                if hasattr(self.config, "model_args"):
                    config.model_args.num_chars = num_chars
                if "model_args" in config:
                    self.args = self.config.model_args
                # This is for backward compatibility
                if "model_params" in config:
                    self.args = self.config.model_params
            else:
                self.config = config
                if "model_args" in config:
                    self.args = self.config.model_args
                # This is for backward compatibility
                if "model_params" in config:
                    self.args = self.config.model_params
        else:
            raise ValueError("config must be either a *Config or *Args")
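
A minimal sketch (not part of the committed file) of the `*Config`/`*Args` naming convention that `_set_model_args` dispatches on; the class names here are hypothetical:

from dataclasses import dataclass, field
from coqpit import Coqpit

@dataclass
class MyVocoderArgs(Coqpit):
    hidden_channels: int = 192

@dataclass
class MyVocoderConfig(Coqpit):
    model_args: MyVocoderArgs = field(default_factory=MyVocoderArgs)

config = MyVocoderConfig()
assert "Config" in config.__class__.__name__  # training config -> args live in config.model_args
assert "model_args" in config                 # Coqpit supports `in` on field names, as used above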
@@ -0,0 +1,33 @@
import torch

from TTS.vocoder.models.melgan_generator import MelganGenerator


class FullbandMelganGenerator(MelganGenerator):
    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(2, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=4,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            proj_kernel=proj_kernel,
            base_channels=base_channels,
            upsample_factors=upsample_factors,
            res_kernel=res_kernel,
            num_res_blocks=num_res_blocks,
        )

    @torch.no_grad()
    def inference(self, cond_features):
        cond_features = cond_features.to(self.layers[1].weight.device)
        cond_features = torch.nn.functional.pad(
            cond_features, (self.inference_padding, self.inference_padding), "replicate"
        )
        return self.layers(cond_features)
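
A minimal shape check for the generator above (a sketch, not part of the committed file). With `upsample_factors=(2, 8, 2, 2)` the hop is 2 * 8 * 2 * 2 = 64 samples per spectrogram frame, and `inference` pads the conditioning by `inference_padding` frames on each side:

import torch

gen = FullbandMelganGenerator(in_channels=80)
mel = torch.randn(1, 80, 50)                  # [B, n_mels, T_frames]
wav = gen.inference(mel)
assert wav.shape == (1, 1, (50 + 2 * gen.inference_padding) * 64)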
@@ -0,0 +1,374 @@
from inspect import signature
from typing import Dict, List, Tuple

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.models import setup_discriminator, setup_generator
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results


class GAN(BaseVocoder):
    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
        """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
        It also helps mixing and matching different generator and discriminator networks easily.

        To implement a new GAN model, you only need to define the generator and the discriminator networks; the
        rest is handled by the `GAN` class.

        Args:
            config (Coqpit): Model configuration.
            ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.

        Examples:
            Initializing the GAN model with HifiGAN generator and discriminator.
            >>> from TTS.vocoder.configs import HifiganConfig
            >>> config = HifiganConfig()
            >>> model = GAN(config)
        """
        super().__init__(config)
        self.config = config
        self.model_g = setup_generator(config)
        self.model_d = setup_discriminator(config)
        self.train_disc = False  # if False, train only the generator.
        self.y_hat_g = None  # the last generator prediction to be passed onto the discriminator
        self.ap = ap

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: output of the GAN generator network.
        """
        return self.model_g.forward(x)

    def inference(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's inference pass.

        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: output of the GAN generator network.
        """
        return self.model_g.inference(x)

    def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator
        network for the current pass.

        Args:
            batch (Dict): Batch of samples returned by the dataloader.
            criterion (Dict): Criterion used to compute the losses.
            optimizer_idx (int): ID of the optimizer in use on the current pass.

        Raises:
            ValueError: `optimizer_idx` is an unexpected value.

        Returns:
            Tuple[Dict, Dict]: model outputs and the computed loss values.
        """
        outputs = {}
        loss_dict = {}

        x = batch["input"]
        y = batch["waveform"]

        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            # DISCRIMINATOR optimization

            # generator pass
            y_hat = self.model_g(x)[:, :, : y.size(2)]

            # cache for generator loss
            # pylint: disable=W0201
            self.y_hat_g = y_hat
            self.y_hat_sub = None
            self.y_sub_g = None

            # PQMF formatting
            if y_hat.shape[1] > 1:
                self.y_hat_sub = y_hat
                y_hat = self.model_g.pqmf_synthesis(y_hat)
                self.y_hat_g = y_hat  # save for generator loss
                self.y_sub_g = self.model_g.pqmf_analysis(y)

            scores_fake, feats_fake, feats_real = None, None, None

            if self.train_disc:
                # use different samples for G and D trainings
                if self.config.diff_samples_for_G_and_D:
                    x_d = batch["input_disc"]
                    y_d = batch["waveform_disc"]
                    # use a different sample than the generator
                    with torch.no_grad():
                        y_hat = self.model_g(x_d)

                    # PQMF formatting
                    if y_hat.shape[1] > 1:
                        y_hat = self.model_g.pqmf_synthesis(y_hat)
                else:
                    # use the same samples as the generator
                    x_d = x.clone()
                    y_d = y.clone()
                    y_hat = self.y_hat_g

                # run D with or without cond. features
                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
                    D_out_real = self.model_d(y_d, x_d)
                else:
                    D_out_fake = self.model_d(y_hat.detach())
                    D_out_real = self.model_d(y_d)

                # format D outputs
                if isinstance(D_out_fake, tuple):
                    # self.model_d returns scores and features
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        scores_real, feats_real = None, None
                    else:
                        scores_real, feats_real = D_out_real
                else:
                    # model D returns only scores
                    scores_fake = D_out_fake
                    scores_real = D_out_real

                # compute losses
                loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
                outputs = {"model_outputs": y_hat}

        if optimizer_idx == 1:
            # GENERATOR loss
            scores_fake, feats_fake, feats_real = None, None, None
            if self.train_disc:
                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(self.y_hat_g, x)
                else:
                    D_out_fake = self.model_d(self.y_hat_g)
                D_out_real = None

                if self.config.use_feat_match_loss:
                    with torch.no_grad():
                        D_out_real = self.model_d(y)

                # format D outputs
                if isinstance(D_out_fake, tuple):
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        feats_real = None
                    else:
                        _, feats_real = D_out_real
                else:
                    scores_fake = D_out_fake
                    feats_fake, feats_real = None, None

            # compute losses
            loss_dict = criterion[optimizer_idx](
                self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
            )
            outputs = {"model_outputs": self.y_hat_g}
        return outputs, loss_dict

    def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
        """Logging shared by the training and evaluation.

        Args:
            name (str): Name of the run. `train` or `eval`.
            ap (AudioProcessor): Audio processor used in training.
            batch (Dict): Batch used in the last train/eval step.
            outputs (Dict): Model outputs from the last train/eval step.

        Returns:
            Tuple[Dict, Dict]: log figures and audio samples.
        """
        y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
        y = batch["waveform"]
        figures = plot_results(y_hat, y, ap, name)
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name}/audio": sample_voice}
        return figures, audios

    def train_log(
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
    ) -> Tuple[Dict, np.ndarray]:
        """Call `_log()` for training."""
        figures, audios = self._log("train", self.ap, batch, outputs)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, self.ap.sample_rate)

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Call `train_step()` with `no_grad()`."""
        self.train_disc = True  # always run the discriminator branch in eval to avoid missing discriminator losses
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
    ) -> Tuple[Dict, np.ndarray]:
        """Call `_log()` for evaluation."""
        figures, audios = self._log("eval", self.ap, batch, outputs)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
        cache: bool = False,
    ) -> None:
        """Load a GAN checkpoint and initialize model parameters.

        Args:
            config (Coqpit): Model config.
            checkpoint_path (str): Checkpoint file path.
            eval (bool, optional): If True, load the model for inference. Defaults to False.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        # band-aid for older than v0.0.15 GAN models
        if "model_disc" in state:
            self.model_g.load_checkpoint(config, checkpoint_path, eval)
        else:
            self.load_state_dict(state["model"])
            if eval:
                self.model_d = None
                if hasattr(self.model_g, "remove_weight_norm"):
                    self.model_g.remove_weight_norm()

    def on_train_step_start(self, trainer) -> None:
        """Enable the discriminator training based on `steps_to_start_discriminator`.

        Args:
            trainer (Trainer): Trainer object.
        """
        self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator

    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.

        It returns 2 optimizers in a list. The first one is for the discriminator and the second one is for the
        generator, matching the `optimizer_idx` convention in `train_step()`.

        Returns:
            List: optimizers.
        """
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
        )
        optimizer2 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
        )
        return [optimizer2, optimizer1]

    def get_lr(self) -> List:
        """Set the initial learning rates for each optimizer.

        Returns:
            List: learning rates for each optimizer.
        """
        return [self.config.lr_disc, self.config.lr_gen]

    def get_scheduler(self, optimizer) -> List:
        """Set the schedulers for each optimizer.

        Args:
            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.

        Returns:
            List: Schedulers, one for each optimizer.
        """
        # `optimizer` is the list returned by `get_optimizer()`: [disc_optimizer, gen_optimizer]
        scheduler_disc = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0])
        scheduler_gen = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1])
        return [scheduler_disc, scheduler_gen]

    @staticmethod
    def format_batch(batch: List) -> Dict:
        """Format the batch for training.

        Args:
            batch (List): Batch out of the dataloader.

        Returns:
            Dict: formatted model inputs.
        """
        if isinstance(batch[0], list):
            x_G, y_G = batch[0]
            x_D, y_D = batch[1]
            return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
        x, y = batch
        return {"input": x, "waveform": y}

    def get_data_loader(  # pylint: disable=no-self-use, unused-argument
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: List,
        verbose: bool,
        num_gpus: int,
        rank: int = None,  # pylint: disable=unused-argument
    ):
        """Initiate and return the GAN dataloader.

        Args:
            config (Coqpit): Model config.
            assets (Dict): Training assets (unused).
            is_eval (bool): Set the dataloader for evaluation if True.
            samples (List): Data samples.
            verbose (bool): Log information if True.
            num_gpus (int): Number of GPUs in use.
            rank (int): Rank of the current GPU. Defaults to None.

        Returns:
            DataLoader: Torch dataloader.
        """
        dataset = GANDataset(
            ap=self.ap,
            items=samples,
            seq_len=config.seq_len,
            hop_len=self.ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_eval else config.batch_size,
            shuffle=num_gpus == 0,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def get_criterion(self):
        """Return criterions for the optimizers."""
        return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]

    @staticmethod
    def init_from_config(config: Coqpit, verbose=True) -> "GAN":
        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        return GAN(config, ap=ap)
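
One convention in this class is easy to miss: `get_optimizer()` returns the discriminator optimizer first, so `optimizer_idx == 0` in `train_step()` is always the discriminator pass. A minimal sketch (not part of the committed file):

from TTS.vocoder.configs import HifiganConfig

config = HifiganConfig()
model = GAN(config)
optimizers = model.get_optimizer()   # [optimizer for model_d, optimizer for model_g]
criteria = model.get_criterion()     # [DiscriminatorLoss, GeneratorLoss], same ordering
assert len(optimizers) == 2 and len(criteria) == 2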
@@ -0,0 +1,217 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import functional as F

LRELU_SLOPE = 0.1


class DiscriminatorP(torch.nn.Module):
    """HiFiGAN Periodic Discriminator.

    Takes every Pth value from the input waveform and applies a stack of convolutions.

    Note:
        if `period` is 2
        `waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat`

    Args:
        x (Tensor): input waveform.

    Returns:
        [Tensor]: discriminator scores per sample in the batch.
        [List[Tensor]]: list of features from each convolutional layer.

    Shapes:
        x: [B, 1, T]
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        get_padding = lambda k, d: int((k * d - d) / 2)
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ]
        )
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            [Tensor]: discriminator scores per sample in the batch.
            [List[Tensor]]: list of features from each convolutional layer.

        Shapes:
            x: [B, 1, T]
        """
        feat = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)

        return x, feat


class MultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Period Discriminator (MPD).

    Wrapper for `DiscriminatorP` to apply it over different periods.
    Periods are suggested to be prime numbers to reduce the overlap between the discriminators.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
                DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
            ]
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            [List[Tensor]]: list of scores from each discriminator.
            [List[List[Tensor]]]: list of lists of features from each discriminator's convolutional layers.

        Shapes:
            x: [B, 1, T]
        """
        scores = []
        feats = []
        for d in self.discriminators:
            score, feat = d(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats


class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator.

    It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper.

    Args:
        use_spectral_norm (bool): if `True`, switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: discriminator scores.
            List[Tensor]: list of features from the convolutional layers.
        """
        feat = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat


class MultiScaleDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Scale Discriminator.

    It is similar to `MelganMultiscaleDiscriminator` but specially tailored for HiFiGAN as in the paper.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)])

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of lists of features from each layer of each discriminator.
        """
        scores = []
        feats = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                x = self.meanpools[i - 1](x)
            score, feat = d(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats


class HifiganDiscriminator(nn.Module):
    """HiFiGAN discriminator wrapping the MPD and the MSD."""

    def __init__(self):
        super().__init__()
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of lists of features from each layer of each discriminator.
        """
        scores, feats = self.mpd(x)
        scores_, feats_ = self.msd(x)
        return scores + scores_, feats + feats_
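
The core trick in `DiscriminatorP` is the 1D-to-2D fold: a waveform of length T is right-padded to a multiple of the period p and viewed as [B, C, T//p, p], so each column holds every p-th sample. A minimal sketch of that reshape in isolation (not part of the committed file):

import torch
import torch.nn.functional as F

x = torch.arange(1, 7, dtype=torch.float32).view(1, 1, 6)   # [1, 2, 3, 4, 5, 6]
p = 4
n_pad = p - (x.shape[-1] % p)                                # 6 -> pad by 2 -> 8
x = F.pad(x, (0, n_pad), "reflect")
x = x.view(1, 1, -1, p)                                      # [B, C, T//p, p]
assert x.shape == (1, 1, 2, 4)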
@@ -0,0 +1,301 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

from TTS.utils.io import load_fsspec

LRELU_SLOPE = 0.1


def get_padding(k, d):
    return int((k * d - d) / 2)


class ResBlock1(torch.nn.Module):
    """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.

    Network::

        x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
        |--------------------------------------------------------------------------------------------------|

    Args:
        channels (int): number of hidden channels for the convolutional layers.
        kernel_size (int): size of the convolution filter in each layer.
        dilation (list): list of dilation values for each conv layer in a block.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
                ),
            ]
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): input tensor.
        Returns:
            Tensor: output tensor.
        Shapes:
            x: [B, C, T]
        """
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_parametrizations(l, "weight")
        for l in self.convs2:
            remove_parametrizations(l, "weight")


class ResBlock2(torch.nn.Module):
    """Residual Block Type 2. It has 1 convolutional layer in each convolutional block.

    Network::

        x -> lrelu -> conv1 -> z -> lrelu -> conv2 -> o -> + -> o
        |---------------------------------------------------|

    Args:
        channels (int): number of hidden channels for the convolutional layers.
        kernel_size (int): size of the convolution filter in each layer.
        dilation (list): list of dilation values for each conv layer in a block.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_parametrizations(l, "weight")


class HifiganGenerator(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        resblock_type,
        resblock_dilation_sizes,
        resblock_kernel_sizes,
        upsample_kernel_sizes,
        upsample_initial_channel,
        upsample_factors,
        inference_padding=5,
        cond_channels=0,
        conv_pre_weight_norm=True,
        conv_post_weight_norm=True,
        conv_post_bias=True,
    ):
        r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF).

        Network::

            x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
                                              ..              -> zI ---|
                                              resblockN_kNx1  -> zN ---'

        Args:
            in_channels (int): number of input tensor channels.
            out_channels (int): number of output tensor channels.
            resblock_type (str): type of the `ResBlock`. '1' or '2'.
            resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
            resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
            upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
            upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
                for each consecutive upsampling layer.
            upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
            inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
        """
        super().__init__()
        self.inference_padding = inference_padding
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_factors)
        # initial upsampling layers
        self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if resblock_type == "1" else ResBlock2
        # upsampling layers
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # MRF blocks
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))
        # post convolution layer
        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
        if cond_channels > 0:
            self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

        if not conv_pre_weight_norm:
            remove_parametrizations(self.conv_pre, "weight")

        if not conv_post_weight_norm:
            remove_parametrizations(self.conv_post, "weight")

    def forward(self, x, g=None):
        """
        Args:
            x (Tensor): feature input tensor.
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            x: [B, C, T]
            output: [B, 1, T * prod(upsample_factors)]
        """
        o = self.conv_pre(x)
        if hasattr(self, "cond_layer"):
            o = o + self.cond_layer(g)
        for i in range(self.num_upsamples):
            o = F.leaky_relu(o, LRELU_SLOPE)
            o = self.ups[i](o)
            z_sum = None
            for j in range(self.num_kernels):
                if z_sum is None:
                    z_sum = self.resblocks[i * self.num_kernels + j](o)
                else:
                    z_sum += self.resblocks[i * self.num_kernels + j](o)
            o = z_sum / self.num_kernels
        o = F.leaky_relu(o)
        o = self.conv_post(o)
        o = torch.tanh(o)
        return o

    @torch.no_grad()
    def inference(self, c):
        """
        Args:
            c (Tensor): conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            c: [B, C, T]
            output: [B, 1, T']
        """
        c = c.to(self.conv_pre.weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.forward(c)

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_parametrizations(l, "weight")
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_parametrizations(self.conv_pre, "weight")
        remove_parametrizations(self.conv_post, "weight")

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()
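
A minimal sketch (not part of the committed file) that instantiates the generator with the standard HiFi-GAN V1 hyper-parameters and checks the 256x upsampling ratio (8 * 8 * 2 * 2):

import torch

gen = HifiganGenerator(
    in_channels=80,
    out_channels=1,
    resblock_type="1",
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2],
)
mel = torch.randn(1, 80, 20)
wav = gen(mel)
assert wav.shape == (1, 1, 20 * 256)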
@@ -0,0 +1,84 @@
import numpy as np
from torch import nn
from torch.nn.utils.parametrizations import weight_norm


class MelganDiscriminator(nn.Module):
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4, 4),
        groups_denominator=4,
    ):
        super().__init__()
        self.layers = nn.ModuleList()

        layer_kernel_size = np.prod(kernel_sizes)
        layer_padding = (layer_kernel_size - 1) // 2

        # initial layer
        self.layers += [
            nn.Sequential(
                nn.ReflectionPad1d(layer_padding),
                weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)),
                nn.LeakyReLU(0.2, inplace=True),
            )
        ]

        # downsampling layers
        layer_in_channels = base_channels
        for downsample_factor in downsample_factors:
            layer_out_channels = min(layer_in_channels * downsample_factor, max_channels)
            layer_kernel_size = downsample_factor * 10 + 1
            layer_padding = (layer_kernel_size - 1) // 2
            layer_groups = layer_in_channels // groups_denominator
            self.layers += [
                nn.Sequential(
                    weight_norm(
                        nn.Conv1d(
                            layer_in_channels,
                            layer_out_channels,
                            kernel_size=layer_kernel_size,
                            stride=downsample_factor,
                            padding=layer_padding,
                            groups=layer_groups,
                        )
                    ),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            ]
            layer_in_channels = layer_out_channels

        # last 2 layers
        layer_padding1 = (kernel_sizes[0] - 1) // 2
        layer_padding2 = (kernel_sizes[1] - 1) // 2
        self.layers += [
            nn.Sequential(
                weight_norm(
                    nn.Conv1d(
                        layer_out_channels,
                        layer_out_channels,
                        kernel_size=kernel_sizes[0],
                        stride=1,
                        padding=layer_padding1,
                    )
                ),
                nn.LeakyReLU(0.2, inplace=True),
            ),
            weight_norm(
                nn.Conv1d(
                    layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2
                )
            ),
        ]

    def forward(self, x):
        feats = []
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        return x, feats
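
A minimal shape check for the discriminator above (a sketch, not part of the committed file). The four stride-4 blocks downsample time by 4**4 = 256 and the score is a sequence of per-window logits, PatchGAN style:

import torch

disc = MelganDiscriminator()
wav = torch.randn(1, 1, 4096)
score, feats = disc(wav)
assert score.shape == (1, 1, 16)      # 4096 / 256 windows
assert len(feats) == 7                # one feature map per layer block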
@@ -0,0 +1,95 @@
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm

from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack


class MelganGenerator(nn.Module):
    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(8, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__()

        # assert model parameters
        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."

        # setup additional model parameters
        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        self.inference_padding = 2

        # initial layer
        layers = []
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
        ]

        # upsampling layers and residual stacks
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2 ** (idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(
                        layer_in_channels,
                        layer_out_channels,
                        layer_filter_size,
                        stride=layer_stride,
                        padding=layer_padding,
                        output_padding=layer_output_padding,
                        bias=True,
                    )
                ),
                ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
            ]

        layers += [nn.LeakyReLU(act_slope)]

        # final layer
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, c):
        return self.layers(c)

    def inference(self, c):
        c = c.to(self.layers[1].weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.layers(c)

    def remove_weight_norm(self):
        for layer in self.layers:
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.parametrize.remove_parametrizations(layer, "weight")
                except ValueError:
                    layer.remove_weight_norm()

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()
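
Each transposed convolution above is configured so its output length is exactly `upsample_factor` times its input length: with kernel `2u`, stride `u`, padding `u // 2 + u % 2`, and output padding `u % 2`, the ConvTranspose1d length formula `(T - 1) * stride - 2 * padding + kernel + output_padding` reduces to `u * T` for both even and odd `u`. A minimal check (a sketch, not part of the committed file):

import torch

gen = MelganGenerator()            # upsample_factors=(8, 8, 2, 2) -> overall hop 256
mel = torch.randn(1, 80, 10)
wav = gen(mel)
assert wav.shape == (1, 1, 10 * 256)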
@@ -0,0 +1,50 @@
from torch import nn

from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator


class MelganMultiscaleDiscriminator(nn.Module):
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_scales=3,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4),
        pooling_kernel_size=4,
        pooling_stride=2,
        pooling_padding=2,
        groups_denominator=4,
    ):
        super().__init__()

        self.discriminators = nn.ModuleList(
            [
                MelganDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    base_channels=base_channels,
                    max_channels=max_channels,
                    downsample_factors=downsample_factors,
                    groups_denominator=groups_denominator,
                )
                for _ in range(num_scales)
            ]
        )

        self.pooling = nn.AvgPool1d(
            kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False
        )

    def forward(self, x):
        scores = []
        feats = []
        for disc in self.discriminators:
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
            x = self.pooling(x)
        return scores, feats
@@ -0,0 +1,41 @@
import torch

from TTS.vocoder.layers.pqmf import PQMF
from TTS.vocoder.models.melgan_generator import MelganGenerator


class MultibandMelganGenerator(MelganGenerator):
    def __init__(
        self,
        in_channels=80,
        out_channels=4,
        proj_kernel=7,
        base_channels=384,
        upsample_factors=(2, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            proj_kernel=proj_kernel,
            base_channels=base_channels,
            upsample_factors=upsample_factors,
            res_kernel=res_kernel,
            num_res_blocks=num_res_blocks,
        )
        self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)

    def pqmf_analysis(self, x):
        return self.pqmf_layer.analysis(x)

    def pqmf_synthesis(self, x):
        return self.pqmf_layer.synthesis(x)

    @torch.no_grad()
    def inference(self, cond_features):
        cond_features = cond_features.to(self.layers[1].weight.device)
        cond_features = torch.nn.functional.pad(
            cond_features, (self.inference_padding, self.inference_padding), "replicate"
        )
        return self.pqmf_synthesis(self.layers(cond_features))
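
The multiband generator predicts 4 subband signals at a quarter of the output sample rate; PQMF synthesis recombines them, so the effective hop is 4 * prod(upsample_factors) = 256. A minimal sketch (not part of the committed file):

import torch

gen = MultibandMelganGenerator()       # out_channels=4, upsample_factors=(2, 8, 2, 2)
mel = torch.randn(1, 80, 10)
subbands = gen(mel)                    # forward() returns the raw subband signals
assert subbands.shape == (1, 4, 10 * 64)
wav = gen.pqmf_synthesis(subbands)     # recombine into a single full-band waveform
assert wav.shape[1] == 1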
@@ -0,0 +1,187 @@
import math

import torch
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations

from TTS.vocoder.layers.parallel_wavegan import ResidualBlock


class ParallelWaveganDiscriminator(nn.Module):
    """PWGAN discriminator as in https://arxiv.org/abs/1910.11480.

    It classifies each audio window as real/fake and returns a sequence
    of predictions.
    It is a stack of convolutional blocks with dilation.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=10,
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
        assert dilation_factor > 0, " [!] dilation factor must be > 0."
        self.conv_layers = nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(num_layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor**i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                nn.Conv1d(
                    conv_in_channels,
                    conv_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    dilation=dilation,
                    bias=bias,
                ),
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]
        self.apply_weight_norm()

    def forward(self, x):
        """
        x: (B, 1, T).
        Returns:
            Tensor: (B, 1, T)
        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                remove_parametrizations(m, "weight")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


class ResidualParallelWaveganDiscriminator(nn.Module):
    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        dropout=0.0,
        bias=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.res_factor = math.sqrt(1.0 / num_layers)

        # check the number of num_layers and stacks
        assert num_layers % stacks == 0
        layers_per_stack = num_layers // stacks

        # define first convolution
        self.first_conv = nn.Sequential(
            nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True),
            getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = nn.ModuleList()
        for layer in range(num_layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=False,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = nn.ModuleList(
            [
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
                nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True),
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
                nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True),
            ]
        )

        # apply weight norm
        self.apply_weight_norm()

    def forward(self, x):
        """
        x: (B, 1, T).
        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= self.res_factor

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                print(f"Weight norm is removed from {m}.")
                remove_parametrizations(m, "weight")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
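
For intuition, the receptive field of the (non-residual) discriminator above with its defaults: the loop builds dilations 1, 1, 2, ..., 8 for the first nine convolutions, plus a final undilated one, and each conv with kernel k and dilation d widens the field by (k - 1) * d. A worked check (a sketch, not part of the committed file):

kernel_size = 3
dilations = [1] + list(range(1, 9)) + [1]          # num_layers=10, dilation_factor=1
receptive_field = 1 + sum((kernel_size - 1) * d for d in dilations)
assert receptive_field == 77                        # samples behind each real/fake score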
@@ -0,0 +1,164 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||
from TTS.vocoder.layers.upsample import ConvUpsample
|
||||
|
||||
|
||||
class ParallelWaveganGenerator(torch.nn.Module):
|
||||
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
|
||||
It is similar to WaveNet with no causal convolution.
|
||||
It is conditioned on an aux feature (spectrogram) to generate
|
||||
an output waveform from an input noise.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_res_blocks=30,
|
||||
stacks=3,
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
aux_channels=80,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
use_weight_norm=True,
|
||||
upsample_factors=[4, 4, 4, 4],
|
||||
inference_padding=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.aux_channels = aux_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.stacks = stacks
|
||||
self.kernel_size = kernel_size
|
||||
self.upsample_factors = upsample_factors
|
||||
self.upsample_scale = np.prod(upsample_factors)
|
||||
self.inference_padding = inference_padding
|
||||
self.use_weight_norm = use_weight_norm
|
||||
|
||||
# check the number of layers and stacks
|
||||
assert num_res_blocks % stacks == 0
|
||||
layers_per_stack = num_res_blocks // stacks
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True)
|
||||
|
||||
# define conv + upsampling network
|
||||
self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)
|
||||
|
||||
# define residual blocks
|
||||
self.conv_layers = torch.nn.ModuleList()
|
||||
for layer in range(num_res_blocks):
|
||||
dilation = 2 ** (layer % layers_per_stack)
|
||||
conv = ResidualBlock(
|
||||
kernel_size=kernel_size,
|
||||
res_channels=res_channels,
|
||||
gate_channels=gate_channels,
|
||||
skip_channels=skip_channels,
|
||||
aux_channels=aux_channels,
|
||||
dilation=dilation,
|
||||
dropout=dropout,
|
||||
bias=bias,
|
||||
)
|
||||
self.conv_layers += [conv]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
c: (B, C ,T').
|
||||
o: Output tensor (B, out_channels, T)
|
||||
"""
|
||||
# random noise
|
||||
x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
|
||||
# perform upsampling
|
||||
if c is not None and self.upsample_net is not None:
|
||||
c = self.upsample_net(c)
|
||||
assert (
|
||||
c.shape[-1] == x.shape[-1]
|
||||
), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
|
||||
|
||||
# encode to hidden representation
|
||||
x = self.first_conv(x)
|
||||
skips = 0
|
||||
for f in self.conv_layers:
|
||||
x, h = f(x, c)
|
||||
skips += h
|
||||
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
||||
|
||||
# apply final layers
|
||||
x = skips
|
||||
for f in self.last_conv_layers:
|
||||
x = f(x)
|
||||
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
c = c.to(self.first_conv.weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
# print(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
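# Worked example of the formula above with this generator's defaults
# (num_res_blocks=30, stacks=3, kernel_size=3): each stack cycles dilations
# 1, 2, 4, ..., 512, summing to 1023, so
#   receptive_field = (3 - 1) * (3 * 1023) + 1 = 6139 samples.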
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
return self._get_receptive_field_size(self.num_res_blocks, self.stacks, self.kernel_size)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.use_weight_norm:
|
||||
self.remove_weight_norm()
|
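# Hedged usage sketch for the noise-driven generator defined above. "WaveNetGenerator"
# is a placeholder for the class name (it sits outside this excerpt); shapes follow the
# forward()/inference() definitions, and 256 = prod(upsample_factors).
import torch

model = WaveNetGenerator(aux_channels=80, upsample_factors=[4, 4, 4, 4])
mel = torch.randn(1, 80, 50)  # (B, aux_channels, T') conditioning frames
wav = model.inference(mel)    # replicate-pads by inference_padding, then vocodes noise
assert wav.shape == (1, 1, (50 + 2 * model.inference_padding) * 256)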
||||
@@ -0,0 +1,203 @@
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GBlock(nn.Module):
|
||||
def __init__(self, in_channels, cond_channels, downsample_factor):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.downsample_factor = downsample_factor
|
||||
|
||||
self.start = nn.Sequential(
|
||||
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1),
|
||||
)
|
||||
self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1)
|
||||
self.end = nn.Sequential(
|
||||
nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2)
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.Conv1d(in_channels, in_channels * 2, kernel_size=1),
|
||||
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
|
||||
)
|
||||
|
||||
def forward(self, inputs, conditions):
|
||||
outputs = self.start(inputs) + self.lc_conv1d(conditions)
|
||||
outputs = self.end(outputs)
|
||||
residual_outputs = self.residual(inputs)
|
||||
outputs = outputs + residual_outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class DBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, downsample_factor):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.downsample_factor = downsample_factor
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.downsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor)
|
||||
self.layers = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.downsample_factor > 1:
|
||||
outputs = self.layers(self.downsample_layer(inputs)) + self.downsample_layer(self.residual(inputs))
|
||||
else:
|
||||
outputs = self.layers(inputs) + self.residual(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class ConditionalDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)):
|
||||
super().__init__()
|
||||
|
||||
assert len(downsample_factors) == len(out_channels) + 1
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.downsample_factors = downsample_factors
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.pre_cond_layers = nn.ModuleList()
|
||||
self.post_cond_layers = nn.ModuleList()
|
||||
|
||||
# layers before condition features
|
||||
self.pre_cond_layers += [DBlock(in_channels, 64, 1)]
|
||||
in_channels = 64
|
||||
for i, channel in enumerate(out_channels):
|
||||
self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i]))
|
||||
in_channels = channel
|
||||
|
||||
# condition block
|
||||
self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1])
|
||||
|
||||
# layers after condition block
|
||||
self.post_cond_layers += [
|
||||
DBlock(in_channels * 2, in_channels * 2, 1),
|
||||
DBlock(in_channels * 2, in_channels * 2, 1),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
nn.Conv1d(in_channels * 2, 1, kernel_size=1),
|
||||
]
|
||||
|
||||
def forward(self, inputs, conditions):
|
||||
batch_size = inputs.size()[0]
|
||||
outputs = inputs.view(batch_size, self.in_channels, -1)
|
||||
for layer in self.pre_cond_layers:
|
||||
outputs = layer(outputs)
|
||||
outputs = self.cond_block(outputs, conditions)
|
||||
for layer in self.post_cond_layers:
|
||||
outputs = layer(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class UnconditionalDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)):
|
||||
super().__init__()
|
||||
|
||||
self.downsample_factors = downsample_factors
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.layers += [DBlock(self.in_channels, base_channels, 1)]
|
||||
in_channels = base_channels
|
||||
for i, factor in enumerate(downsample_factors):
|
||||
self.layers.append(DBlock(in_channels, out_channels[i], factor))
|
||||
in_channels = out_channels[i]  # track the width of the block just added
|
||||
self.layers += [
|
||||
DBlock(in_channels, in_channels, 1),
|
||||
DBlock(in_channels, in_channels, 1),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
nn.Conv1d(in_channels, 1, kernel_size=1),
|
||||
]
|
||||
|
||||
def forward(self, inputs):
|
||||
batch_size = inputs.size()[0]
|
||||
outputs = inputs.view(batch_size, self.in_channels, -1)
|
||||
for layer in self.layers:
|
||||
outputs = layer(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class RandomWindowDiscriminator(nn.Module):
|
||||
"""Random Window Discriminator as described in
|
||||
http://arxiv.org/abs/1909.11646"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cond_channels,
|
||||
hop_length,
|
||||
uncond_disc_downsample_factors=(8, 4),
|
||||
cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)),
|
||||
cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)),
|
||||
window_sizes=(512, 1024, 2048, 4096, 8192),
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_channels = cond_channels
|
||||
self.window_sizes = window_sizes
|
||||
self.hop_length = hop_length
|
||||
self.base_window_size = self.hop_length * 2
|
||||
self.ks = [ws // self.base_window_size for ws in window_sizes]
|
||||
|
||||
# check arguments
|
||||
assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes)
|
||||
for ws in window_sizes:
|
||||
assert ws % hop_length == 0
|
||||
|
||||
for idx, cf in enumerate(cond_disc_downsample_factors):
|
||||
assert np.prod(cf) == hop_length // self.ks[idx]
|
||||
|
||||
# define layers
|
||||
self.unconditional_discriminators = nn.ModuleList([])
|
||||
for k in self.ks:
|
||||
layer = UnconditionalDiscriminator(
|
||||
in_channels=k, base_channels=64, downsample_factors=uncond_disc_downsample_factors
|
||||
)
|
||||
self.unconditional_discriminators.append(layer)
|
||||
|
||||
self.conditional_discriminators = nn.ModuleList([])
|
||||
for idx, k in enumerate(self.ks):
|
||||
layer = ConditionalDiscriminator(
|
||||
in_channels=k,
|
||||
cond_channels=cond_channels,
|
||||
downsample_factors=cond_disc_downsample_factors[idx],
|
||||
out_channels=cond_disc_out_channels[idx],
|
||||
)
|
||||
self.conditional_discriminators.append(layer)
|
||||
|
||||
def forward(self, x, c):
|
||||
scores = []
|
||||
feats = []
|
||||
# unconditional pass
|
||||
for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators):
|
||||
index = np.random.randint(x.shape[-1] - window_size)
|
||||
|
||||
score = layer(x[:, :, index : index + window_size])
|
||||
scores.append(score)
|
||||
|
||||
# conditional pass
|
||||
for window_size, layer in zip(self.window_sizes, self.conditional_discriminators):
|
||||
frame_size = window_size // self.hop_length
|
||||
lc_index = np.random.randint(c.shape[-1] - frame_size)
|
||||
sample_index = lc_index * self.hop_length
|
||||
x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length]
|
||||
c_sub = c[:, :, lc_index : lc_index + frame_size]
|
||||
|
||||
score = layer(x_sub, c_sub)
|
||||
scores.append(score)
|
||||
return scores, feats
|
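# Worked check of the window bookkeeping enforced in __init__ above, assuming the
# common hop_length=256 (illustrative value): each window contributes k = ws // (2 * hop)
# input channels, and the matching conditional branch must downsample by hop // k.
import numpy as np

hop_length = 256
window_sizes = (512, 1024, 2048, 4096, 8192)
ks = [ws // (2 * hop_length) for ws in window_sizes]  # -> [1, 2, 4, 8, 16]
factors = ((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2))
for k, f in zip(ks, factors):
    assert np.prod(f) == hop_length // k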
||||
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class SpecDiscriminator(nn.Module):
|
||||
"""docstring for Discriminator."""
|
||||
|
||||
def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
||||
self.fft_size = fft_size
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.stft = TorchSTFT(fft_size, hop_length, win_length)
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
|
||||
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
||||
|
||||
def forward(self, y):
|
||||
fmap = []
|
||||
with torch.no_grad():
|
||||
y = y.squeeze(1)
|
||||
y = self.stft(y)
|
||||
y = y.unsqueeze(1)
|
||||
for _, d in enumerate(self.discriminators):
|
||||
y = d(y)
|
||||
y = F.leaky_relu(y, LRELU_SLOPE)
|
||||
fmap.append(y)
|
||||
|
||||
y = self.out(y)
|
||||
fmap.append(y)
|
||||
|
||||
return torch.flatten(y, 1, -1), fmap
|
||||
|
||||
|
||||
class MultiResSpecDiscriminator(torch.nn.Module):
|
||||
def __init__( # pylint: disable=dangerous-default-value
|
||||
self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window"
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
# SpecDiscriminator's fourth parameter is use_spectral_norm, so the window string
# must not be forwarded positionally (a non-empty string would enable spectral norm)
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0]),
|
||||
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1]),
|
||||
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2]),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
scores = []
|
||||
feats = []
|
||||
for d in self.discriminators:
|
||||
score, feat = d(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
|
||||
return scores, feats
|
||||
|
||||
|
||||
class UnivnetDiscriminator(nn.Module):
|
||||
"""Univnet discriminator wrapping MPD and MSD."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
self.msd = MultiResSpecDiscriminator()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: discriminator scores.
|
||||
List[List[Tensor]]: Feature maps from each layer of each discriminator.
|
||||
"""
|
||||
scores, feats = self.mpd(x)
|
||||
scores_, feats_ = self.msd(x)
|
||||
return scores + scores_, feats + feats_
|
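# Minimal usage sketch for the wrapper above: both branches score raw waveforms, so a
# GAN training loop can sum adversarial losses over the returned score list.
import torch

disc = UnivnetDiscriminator()
wavs = torch.randn(4, 1, 8192)  # (B, 1, T) real or generated audio
scores, feats = disc(wavs)      # MPD scores + 3 multi-resolution STFT scores, with feature maps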
||||
@@ -0,0 +1,157 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
from TTS.vocoder.layers.lvc_block import LVCBlock
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class UnivnetGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int,
|
||||
cond_channels: int,
|
||||
upsample_factors: List[int],
|
||||
lvc_layers_each_block: int,
|
||||
lvc_kernel_size: int,
|
||||
kpnet_hidden_channels: int,
|
||||
kpnet_conv_size: int,
|
||||
dropout: float,
|
||||
use_weight_norm=True,
|
||||
):
|
||||
"""Univnet Generator network.
|
||||
|
||||
Paper: https://arxiv.org/pdf/2106.07889.pdf
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input tensor channels.
|
||||
out_channels (int): Number of channels of the output tensor.
|
||||
hidden_channels (int): Number of hidden network channels.
|
||||
cond_channels (int): Number of channels of the conditioning tensors.
|
||||
upsample_factors (List[int]): List of upsample factors for the upsampling layers.
|
||||
lvc_layers_each_block (int): Number of LVC layers in each block.
|
||||
lvc_kernel_size (int): Kernel size of the LVC layers.
|
||||
kpnet_hidden_channels (int): Number of hidden channels in the kernel predictor network (KPNet).
|
||||
kpnet_conv_size (int): Convolution kernel size of the kernel predictor network (KPNet).
|
||||
dropout (float): Dropout rate.
|
||||
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.upsample_scale = np.prod(upsample_factors)
|
||||
self.lvc_block_nums = len(upsample_factors)
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = torch.nn.Conv1d(
|
||||
in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
|
||||
)
|
||||
|
||||
# define residual blocks
|
||||
self.lvc_blocks = torch.nn.ModuleList()
|
||||
cond_hop_length = 1
|
||||
for n in range(self.lvc_block_nums):
|
||||
cond_hop_length = cond_hop_length * upsample_factors[n]
|
||||
lvcb = LVCBlock(
|
||||
in_channels=hidden_channels,
|
||||
cond_channels=cond_channels,
|
||||
upsample_ratio=upsample_factors[n],
|
||||
conv_layers=lvc_layers_each_block,
|
||||
conv_kernel_size=lvc_kernel_size,
|
||||
cond_hop_length=cond_hop_length,
|
||||
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
kpnet_dropout=dropout,
|
||||
)
|
||||
self.lvc_blocks += [lvcb]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Conv1d(
|
||||
hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, c):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
c (Tensor): Local conditioning auxiliary features (B, C ,T').
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T)
|
||||
"""
|
||||
# random noise
|
||||
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
x = self.first_conv(x)
|
||||
|
||||
for n in range(self.lvc_block_nums):
|
||||
x = self.lvc_blocks[n](x, c)
|
||||
|
||||
# apply final layers
|
||||
for f in self.last_conv_layers:
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = f(x)
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization module from all of the layers."""
|
||||
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
parametrize.remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
"""Apply weight normalization module from all of the layers."""
|
||||
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
# print(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
"""Return receptive field size."""
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""Perform inference.
|
||||
Args:
|
||||
c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T)
|
||||
"""
|
||||
|
||||
c = c.to(next(self.parameters()))
|
||||
return self.forward(c)
|
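# Hedged usage sketch for UnivnetGenerator. The hyperparameter values below are
# illustrative placeholders (loosely following common UnivNet setups), not values
# taken from this diff.
import torch

gen = UnivnetGenerator(
    in_channels=64, out_channels=1, hidden_channels=32, cond_channels=80,
    upsample_factors=[8, 8, 4], lvc_layers_each_block=4, lvc_kernel_size=3,
    kpnet_hidden_channels=64, kpnet_conv_size=3, dropout=0.0,
)
mel = torch.randn(2, 80, 40)  # (B, cond_channels, T') mel frames
wav = gen(mel)                # noise of shape (B, in_channels, T') is drawn internally
assert wav.shape == (2, 1, 40 * 8 * 8 * 4)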
||||
@@ -0,0 +1,345 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets import WaveGradDataset
|
||||
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavegradArgs(Coqpit):
|
||||
in_channels: int = 80
|
||||
out_channels: int = 1
|
||||
use_weight_norm: bool = False
|
||||
y_conv_channels: int = 32
|
||||
x_conv_channels: int = 768
|
||||
dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
|
||||
ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
|
||||
upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
|
||||
upsample_dilations: List[List[int]] = field(
|
||||
default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
|
||||
)
|
||||
|
||||
|
||||
class Wavegrad(BaseVocoder):
|
||||
"""🐸 🌊 WaveGrad 🌊 model.
|
||||
Paper - https://arxiv.org/abs/2009.00713
|
||||
|
||||
Examples:
|
||||
Initializing the model.
|
||||
|
||||
>>> from TTS.vocoder.configs import WavegradConfig
|
||||
>>> config = WavegradConfig()
|
||||
>>> model = Wavegrad(config)
|
||||
|
||||
Paper Abstract:
|
||||
This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
|
||||
data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
|
||||
from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
|
||||
on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
|
||||
the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
|
||||
terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
|
||||
Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
|
||||
baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
|
||||
Audio samples are available at this https URL.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.use_weight_norm = config.model_params.use_weight_norm
|
||||
self.hop_len = np.prod(config.model_params.upsample_factors)
|
||||
self.noise_level = None
|
||||
self.num_steps = None
|
||||
self.beta = None
|
||||
self.alpha = None
|
||||
self.alpha_hat = None
|
||||
self.c1 = None
|
||||
self.c2 = None
|
||||
self.sigma = None
|
||||
|
||||
# dblocks
|
||||
self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
|
||||
self.dblocks = nn.ModuleList([])
|
||||
ic = config.model_params.y_conv_channels
|
||||
for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
|
||||
self.dblocks.append(DBlock(ic, oc, df))
|
||||
ic = oc
|
||||
|
||||
# film
|
||||
self.film = nn.ModuleList([])
|
||||
ic = config.model_params.y_conv_channels
|
||||
for oc in reversed(config.model_params.ublock_out_channels):
|
||||
self.film.append(FiLM(ic, oc))
|
||||
ic = oc
|
||||
|
||||
# ublocks
|
||||
self.ublocks = nn.ModuleList([])
|
||||
ic = config.model_params.x_conv_channels
|
||||
for oc, uf, ud in zip(
|
||||
config.model_params.ublock_out_channels,
|
||||
config.model_params.upsample_factors,
|
||||
config.model_params.upsample_dilations,
|
||||
):
|
||||
self.ublocks.append(UBlock(ic, oc, uf, ud))
|
||||
ic = oc
|
||||
|
||||
self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
|
||||
self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)
|
||||
|
||||
if config.model_params.use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, x, spectrogram, noise_scale):
|
||||
shift_and_scale = []
|
||||
|
||||
x = self.y_conv(x)
|
||||
shift_and_scale.append(self.film[0](x, noise_scale))
|
||||
|
||||
for film, layer in zip(self.film[1:], self.dblocks):
|
||||
x = layer(x)
|
||||
shift_and_scale.append(film(x, noise_scale))
|
||||
|
||||
x = self.x_conv(spectrogram)
|
||||
for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)):
|
||||
x = layer(x, film_shift, film_scale)
|
||||
x = self.out_conv(x)
|
||||
return x
|
||||
|
||||
def load_noise_schedule(self, path):
|
||||
beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg
|
||||
self.compute_noise_level(beta)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, y_n=None):
|
||||
"""
|
||||
Shapes:
|
||||
x: :math:`[B, C , T]`
|
||||
y_n: :math:`[B, 1, T]`
|
||||
"""
|
||||
if y_n is None:
|
||||
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
|
||||
else:
|
||||
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
|
||||
y_n = y_n.type_as(x)
|
||||
sqrt_alpha_hat = self.noise_level.to(x)
|
||||
for n in range(len(self.alpha) - 1, -1, -1):
|
||||
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
|
||||
if n > 0:
|
||||
z = torch.randn_like(y_n)
|
||||
y_n += self.sigma[n - 1] * z
|
||||
y_n.clamp_(-1.0, 1.0)
|
||||
return y_n
|
||||
|
||||
def compute_y_n(self, y_0):
|
||||
"""Compute noisy audio based on noise schedule"""
|
||||
self.noise_level = self.noise_level.to(y_0)
|
||||
if len(y_0.shape) == 3:
|
||||
y_0 = y_0.squeeze(1)
|
||||
s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
|
||||
l_a, l_b = self.noise_level[s], self.noise_level[s + 1]
|
||||
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
|
||||
noise_scale = noise_scale.unsqueeze(1)
|
||||
noise = torch.randn_like(y_0)
|
||||
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
|
||||
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
|
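# Note: this is WaveGrad's continuous noise-level sampling -- noise_scale is drawn
# uniformly between sqrt(alpha_hat[s]) and sqrt(alpha_hat[s + 1]) rather than at a
# fixed discrete step, which is what lets the schedule length change at test time.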
||||
|
||||
def compute_noise_level(self, beta):
|
||||
"""Compute noise schedule parameters"""
|
||||
self.num_steps = len(beta)
|
||||
alpha = 1 - beta
|
||||
alpha_hat = np.cumprod(alpha)
|
||||
noise_level = alpha_hat**0.5
|
||||
|
||||
# pylint: disable=not-callable
|
||||
self.beta = torch.tensor(beta.astype(np.float32))
|
||||
self.alpha = torch.tensor(alpha.astype(np.float32))
|
||||
self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
|
||||
self.noise_level = torch.tensor(noise_level.astype(np.float32))
|
||||
|
||||
self.c1 = 1 / self.alpha**0.5
|
||||
self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
|
||||
self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5
|
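# The constants above implement the standard DDPM ancestral-sampling update used in
# inference():
#   y_{n-1} = (1 / sqrt(alpha_n)) * (y_n - (1 - alpha_n) / sqrt(1 - alpha_hat_n) * eps_theta) + sigma_{n-1} * z
# with z ~ N(0, I), i.e. c1 = 1 / sqrt(alpha) and c2 = (1 - alpha) / sqrt(1 - alpha_hat).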
||||
|
||||
def remove_weight_norm(self):
|
||||
for _, layer in enumerate(self.dblocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.film):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.ublocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
remove_parametrizations(self.x_conv, "weight")
|
||||
remove_parametrizations(self.out_conv, "weight")
|
||||
remove_parametrizations(self.y_conv, "weight")
|
||||
|
||||
def apply_weight_norm(self):
|
||||
for _, layer in enumerate(self.dblocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.film):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.ublocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
self.x_conv = weight_norm(self.x_conv)
|
||||
self.out_conv = weight_norm(self.out_conv)
|
||||
self.y_conv = weight_norm(self.y_conv)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.config.model_params.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
betas = np.linspace(
|
||||
config["test_noise_schedule"]["min_val"],
|
||||
config["test_noise_schedule"]["max_val"],
|
||||
config["test_noise_schedule"]["num_steps"],
|
||||
)
|
||||
self.compute_noise_level(betas)
|
||||
else:
|
||||
betas = np.linspace(
|
||||
config["train_noise_schedule"]["min_val"],
|
||||
config["train_noise_schedule"]["max_val"],
|
||||
config["train_noise_schedule"]["num_steps"],
|
||||
)
|
||||
self.compute_noise_level(betas)
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
# format data
|
||||
x = batch["input"]
|
||||
y = batch["waveform"]
|
||||
|
||||
# set noise scale
|
||||
noise, x_noisy, noise_scale = self.compute_y_n(y)
|
||||
|
||||
# forward pass
|
||||
noise_hat = self.forward(x_noisy, x, noise_scale)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(noise, noise_hat)
|
||||
return {"model_output": noise_hat}, {"loss": loss}
|
||||
|
||||
def train_log( # pylint: disable=no-self-use
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log( # pylint: disable=no-self-use
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
|
||||
# setup noise schedule and inference
|
||||
ap = assets["audio_processor"]
|
||||
noise_schedule = self.config["test_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for sample in samples:
|
||||
x = sample[0]
|
||||
x = x[None, :, :].to(next(self.parameters()).device)
|
||||
y = sample[1]
|
||||
y = y[None, :]
|
||||
# compute voice
|
||||
y_pred = self.inference(x)
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_pred, y, ap, "test")
|
||||
# Sample audio
|
||||
sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
|
||||
return figures, {"test/audio": sample_voice}
|
||||
|
||||
def get_optimizer(self):
|
||||
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
|
||||
|
||||
def get_scheduler(self, optimizer):
|
||||
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
|
||||
|
||||
@staticmethod
|
||||
def get_criterion():
|
||||
return torch.nn.L1Loss()
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: Dict) -> Dict:
|
||||
# return a whole audio segment
|
||||
m, y = batch[0], batch[1]
|
||||
y = y.unsqueeze(1)
|
||||
return {"input": m, "waveform": y}
|
||||
|
||||
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: bool, samples: List, verbose: bool, num_gpus: int):
|
||||
ap = assets["audio_processor"]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=samples,
|
||||
seq_len=self.config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=self.config.pad_short,
|
||||
conv_pad=self.config.conv_pad,
|
||||
is_training=not is_eval,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=config.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
shuffle=num_gpus <= 1,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def on_epoch_start(self, trainer): # pylint: disable=unused-argument
|
||||
noise_schedule = self.config["train_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavegradConfig"):
|
||||
return Wavegrad(config)
|
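# Hedged end-to-end sketch: vocoding a mel spectrogram with WaveGrad. "config" is
# assumed to be a WavegradConfig carrying the test_noise_schedule dict used in
# load_checkpoint() above; the mel tensor is random for illustration.
import numpy as np
import torch

model = Wavegrad(config)
schedule = config["test_noise_schedule"]
betas = np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])
model.compute_noise_level(betas)
mel = torch.randn(1, config.model_params.in_channels, 80)  # (B, n_mels, frames)
wav = model.inference(mel)                                 # (B, 1, frames * hop_len)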
||||
@@ -0,0 +1,646 @@
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_decode
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
from TTS.vocoder.layers.losses import WaveRNNLoss
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
|
||||
|
||||
|
||||
def stream(string, variables):
|
||||
sys.stdout.write(f"\r{string}" % variables)
|
||||
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
# relates https://github.com/pytorch/pytorch/issues/42305
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
||||
self.batch_norm1 = nn.BatchNorm1d(dims)
|
||||
self.batch_norm2 = nn.BatchNorm1d(dims)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.batch_norm1(x)
|
||||
x = F.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.batch_norm2(x)
|
||||
return x + residual
|
||||
|
||||
|
||||
class MelResNet(nn.Module):
|
||||
def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
|
||||
super().__init__()
|
||||
k_size = pad * 2 + 1
|
||||
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
|
||||
self.batch_norm = nn.BatchNorm1d(compute_dims)
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_res_blocks):
|
||||
self.layers.append(ResBlock(compute_dims))
|
||||
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.batch_norm(x)
|
||||
x = F.relu(x)
|
||||
for f in self.layers:
|
||||
x = f(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Stretch2d(nn.Module):
|
||||
def __init__(self, x_scale, y_scale):
|
||||
super().__init__()
|
||||
self.x_scale = x_scale
|
||||
self.y_scale = y_scale
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
x = x.unsqueeze(-1).unsqueeze(3)
|
||||
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
|
||||
return x.view(b, c, h * self.y_scale, w * self.x_scale)
|
||||
|
||||
|
||||
class UpsampleNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feat_dims,
|
||||
upsample_scales,
|
||||
compute_dims,
|
||||
num_res_blocks,
|
||||
res_out_dims,
|
||||
pad,
|
||||
use_aux_net,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_scale = np.cumprod(upsample_scales)[-1]
|
||||
self.indent = pad * self.total_scale
|
||||
self.use_aux_net = use_aux_net
|
||||
if use_aux_net:
|
||||
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
||||
self.resnet_stretch = Stretch2d(self.total_scale, 1)
|
||||
self.up_layers = nn.ModuleList()
|
||||
for scale in upsample_scales:
|
||||
k_size = (1, scale * 2 + 1)
|
||||
padding = (0, scale)
|
||||
stretch = Stretch2d(scale, 1)
|
||||
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
|
||||
conv.weight.data.fill_(1.0 / k_size[1])
|
||||
self.up_layers.append(stretch)
|
||||
self.up_layers.append(conv)
|
||||
|
||||
def forward(self, m):
|
||||
if self.use_aux_net:
|
||||
aux = self.resnet(m).unsqueeze(1)
|
||||
aux = self.resnet_stretch(aux)
|
||||
aux = aux.squeeze(1)
|
||||
aux = aux.transpose(1, 2)
|
||||
else:
|
||||
aux = None
|
||||
m = m.unsqueeze(1)
|
||||
for f in self.up_layers:
|
||||
m = f(m)
|
||||
m = m.squeeze(1)[:, :, self.indent : -self.indent]
|
||||
return m.transpose(1, 2), aux
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.pad = pad
|
||||
self.indent = pad * scale
|
||||
self.use_aux_net = use_aux_net
|
||||
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
||||
|
||||
def forward(self, m):
|
||||
if self.use_aux_net:
|
||||
aux = self.resnet(m)
|
||||
aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True)
|
||||
aux = aux.transpose(1, 2)
|
||||
else:
|
||||
aux = None
|
||||
m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True)
|
||||
m = m[:, :, self.indent : -self.indent]
|
||||
m = m * 0.045 # empirically found
|
||||
|
||||
return m.transpose(1, 2), aux
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavernnArgs(Coqpit):
|
||||
"""🐸 WaveRNN model arguments.
|
||||
|
||||
Args:
rnn_dims (int):
|
||||
Number of hidden channels in RNN layers. Defaults to 512.
|
||||
fc_dims (int):
|
||||
Number of hidden channels in fully-connected layers. Defaults to 512.
|
||||
compute_dims (int):
|
||||
Number of hidden channels in the feature ResNet. Defaults to 128.
|
||||
res_out_dims (int):
|
||||
Number of hidden channels in the feature ResNet output. Defaults to 128.
|
||||
num_res_blocks (int):
|
||||
Number of residual blocks in the ResNet. Defaults to 10.
|
||||
use_aux_net (bool):
|
||||
Enable/disable the feature ResNet. Defaults to True.
|
||||
use_upsample_net (bool):
|
||||
Enable/disable the upsampling network. If False, basic interpolation upsampling is used. Defaults to True.
|
||||
upsample_factors (list):
|
||||
Upsampling factors. The product of the values must equal `hop_length`. Defaults to ```[4, 8, 8]```.
|
||||
mode (str):
|
||||
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
|
||||
Gaussian Distribution and `bits` for quantized bits as the model's output.
|
||||
mulaw (bool):
|
||||
Enable/disable the use of mu-law quantization for training. Only applicable if `mode == 'bits'`. Defaults
|
||||
to `True`.
|
||||
pad (int):
|
||||
Padding applied to the input feature frames against the convolution layers of the feature network.
|
||||
Defaults to 2.
|
||||
"""
|
||||
|
||||
rnn_dims: int = 512
|
||||
fc_dims: int = 512
|
||||
compute_dims: int = 128
|
||||
res_out_dims: int = 128
|
||||
num_res_blocks: int = 10
|
||||
use_aux_net: bool = True
|
||||
use_upsample_net: bool = True
|
||||
upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
|
||||
mode: str = "mold" # mold [string], gauss [string], bits [int]
|
||||
mulaw: bool = True # apply mulaw if mode is bits
|
||||
pad: int = 2
|
||||
feat_dims: int = 80
|
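# How `mode` maps onto the output layer size chosen in Wavernn.__init__ below:
#   mode=<int>   -> n_classes = 2**mode categorical logits over quantized samples
#   mode="mold"  -> n_classes = 3 * 10 (10 logistic components x mean/scale/weight)
#   mode="gauss" -> n_classes = 2 (mean, log_std)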
||||
|
||||
|
||||
class Wavernn(BaseVocoder):
|
||||
def __init__(self, config: Coqpit):
|
||||
"""🐸 WaveRNN model.
|
||||
Original paper - https://arxiv.org/abs/1802.08435
|
||||
Official implementation - https://github.com/fatchord/WaveRNN
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `mode` is not an int, "mold", or "gauss".
|
||||
|
||||
Examples:
|
||||
>>> from TTS.vocoder.configs import WavernnConfig
|
||||
>>> config = WavernnConfig()
|
||||
>>> model = Wavernn(config)
|
||||
|
||||
Paper Abstract:
|
||||
Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
|
||||
both estimating the data distribution and generating high-quality samples. Efficient sampling for this
|
||||
class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
|
||||
describe a set of general techniques for reducing sampling time while maintaining high output quality.
|
||||
We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
|
||||
matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
|
||||
possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
|
||||
pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
|
||||
parameters, large sparse networks perform better than small dense networks and this relationship holds for
|
||||
sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
|
||||
high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
|
||||
subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
|
||||
samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
|
||||
orthogonal method for increasing sampling efficiency.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
if isinstance(self.args.mode, int):
|
||||
self.n_classes = 2**self.args.mode
|
||||
elif self.args.mode == "mold":
|
||||
self.n_classes = 3 * 10
|
||||
elif self.args.mode == "gauss":
|
||||
self.n_classes = 2
|
||||
else:
|
||||
raise RuntimeError(f"Unknown model mode value - {self.args.mode}")
|
||||
|
||||
self.ap = AudioProcessor(**config.audio.to_dict())
|
||||
self.aux_dims = self.args.res_out_dims // 4
|
||||
|
||||
if self.args.use_upsample_net:
|
||||
assert (
|
||||
np.cumprod(self.args.upsample_factors)[-1] == config.audio.hop_length
|
||||
), " [!] upsample scales needs to be equal to hop_length"
|
||||
self.upsample = UpsampleNetwork(
|
||||
self.args.feat_dims,
|
||||
self.args.upsample_factors,
|
||||
self.args.compute_dims,
|
||||
self.args.num_res_blocks,
|
||||
self.args.res_out_dims,
|
||||
self.args.pad,
|
||||
self.args.use_aux_net,
|
||||
)
|
||||
else:
|
||||
self.upsample = Upsample(
|
||||
config.audio.hop_length,
|
||||
self.args.pad,
|
||||
self.args.num_res_blocks,
|
||||
self.args.feat_dims,
|
||||
self.args.compute_dims,
|
||||
self.args.res_out_dims,
|
||||
self.args.use_aux_net,
|
||||
)
|
||||
if self.args.use_aux_net:
|
||||
self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
|
||||
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
|
||||
self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
|
||||
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
|
||||
else:
|
||||
self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
|
||||
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
|
||||
self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
|
||||
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
|
||||
|
||||
def forward(self, x, mels):
|
||||
bsize = x.size(0)
|
||||
h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
|
||||
h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
|
||||
mels, aux = self.upsample(mels)
|
||||
|
||||
if self.args.use_aux_net:
|
||||
aux_idx = [self.aux_dims * i for i in range(5)]
|
||||
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
|
||||
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
|
||||
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
|
||||
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
|
||||
|
||||
x = (
|
||||
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
|
||||
if self.args.use_aux_net
|
||||
else torch.cat([x.unsqueeze(-1), mels], dim=2)
|
||||
)
|
||||
x = self.I(x)
|
||||
res = x
|
||||
self.rnn1.flatten_parameters()
|
||||
x, _ = self.rnn1(x, h1)
|
||||
|
||||
x = x + res
|
||||
res = x
|
||||
x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
|
||||
self.rnn2.flatten_parameters()
|
||||
x, _ = self.rnn2(x, h2)
|
||||
|
||||
x = x + res
|
||||
x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc1(x))
|
||||
|
||||
x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc2(x))
|
||||
return self.fc3(x)
|
||||
|
||||
def inference(self, mels, batched=None, target=None, overlap=None):
|
||||
self.eval()
|
||||
output = []
|
||||
start = time.time()
|
||||
rnn1 = self.get_gru_cell(self.rnn1)
|
||||
rnn2 = self.get_gru_cell(self.rnn2)
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(mels, np.ndarray):
|
||||
mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
|
||||
|
||||
if mels.ndim == 2:
|
||||
mels = mels.unsqueeze(0)
|
||||
wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length
|
||||
|
||||
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
|
||||
mels, aux = self.upsample(mels.transpose(1, 2))
|
||||
|
||||
if batched:
|
||||
mels = self.fold_with_overlap(mels, target, overlap)
|
||||
if aux is not None:
|
||||
aux = self.fold_with_overlap(aux, target, overlap)
|
||||
|
||||
b_size, seq_len, _ = mels.size()
|
||||
|
||||
h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
|
||||
h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
|
||||
x = torch.zeros(b_size, 1).type_as(mels)
|
||||
|
||||
if self.args.use_aux_net:
|
||||
d = self.aux_dims
|
||||
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
|
||||
|
||||
for i in range(seq_len):
|
||||
m_t = mels[:, i, :]
|
||||
|
||||
if self.args.use_aux_net:
|
||||
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
|
||||
|
||||
x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
|
||||
x = self.I(x)
|
||||
h1 = rnn1(x, h1)
|
||||
|
||||
x = x + h1
|
||||
inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
|
||||
h2 = rnn2(inp, h2)
|
||||
|
||||
x = x + h2
|
||||
x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc1(x))
|
||||
|
||||
x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc2(x))
|
||||
|
||||
logits = self.fc3(x)
|
||||
|
||||
if self.args.mode == "mold":
|
||||
sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).type_as(mels)
|
||||
elif self.args.mode == "gauss":
|
||||
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).type_as(mels)
|
||||
elif isinstance(self.args.mode, int):
|
||||
posterior = F.softmax(logits, dim=1)
|
||||
distrib = torch.distributions.Categorical(posterior)
|
||||
|
||||
sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
|
||||
output.append(sample)
|
||||
x = sample.unsqueeze(-1)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown model mode value - {self.args.mode}")
|
||||
|
||||
if i % 100 == 0:
|
||||
self.gen_display(i, seq_len, b_size, start)
|
||||
|
||||
output = torch.stack(output).transpose(0, 1)
|
||||
output = output.cpu()
|
||||
if batched:
|
||||
output = output.numpy()
|
||||
output = output.astype(np.float64)
|
||||
|
||||
output = self.xfade_and_unfold(output, target, overlap)
|
||||
else:
|
||||
output = output[0]
|
||||
|
||||
if self.args.mulaw and isinstance(self.args.mode, int):
|
||||
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
|
||||
|
||||
# Fade-out at the end to avoid signal cutting out suddenly
|
||||
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
|
||||
output = output[:wave_len]
|
||||
|
||||
if wave_len > len(fade_out):
|
||||
output[-20 * self.config.audio.hop_length :] *= fade_out
|
||||
|
||||
self.train()
|
||||
return output
|
||||
|
||||
def gen_display(self, i, seq_len, b_size, start):
|
||||
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
|
||||
realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
|
||||
stream(
|
||||
"%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
|
||||
(i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
|
||||
)
|
||||
|
||||
def fold_with_overlap(self, x, target, overlap):
|
||||
"""Fold the tensor with overlap for quick batched inference.
|
||||
Overlap will be used for crossfading in xfade_and_unfold()
|
||||
Args:
|
||||
x (tensor) : Upsampled conditioning features.
|
||||
shape=(1, timesteps, features)
|
||||
target (int) : Target timesteps for each index of batch
|
||||
overlap (int) : Timesteps for both xfade and rnn warmup
|
||||
Return:
|
||||
(tensor) : shape=(num_folds, target + 2 * overlap, features)
|
||||
Details:
|
||||
x = [[h1, h2, ... hn]]
|
||||
Where each h is a vector of conditioning features
|
||||
Eg: target=2, overlap=1 with x.size(1)=10
|
||||
folded = [[h1, h2, h3, h4],
|
||||
[h4, h5, h6, h7],
|
||||
[h7, h8, h9, h10]]
|
||||
"""
|
||||
|
||||
_, total_len, features = x.size()
|
||||
|
||||
# Calculate variables needed
|
||||
num_folds = (total_len - overlap) // (target + overlap)
|
||||
extended_len = num_folds * (overlap + target) + overlap
|
||||
remaining = total_len - extended_len
|
||||
|
||||
# Pad if some time steps poking out
|
||||
if remaining != 0:
|
||||
num_folds += 1
|
||||
padding = target + 2 * overlap - remaining
|
||||
x = self.pad_tensor(x, padding, side="after")
|
||||
|
||||
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
|
||||
|
||||
# Get the values for the folded tensor
|
||||
for i in range(num_folds):
|
||||
start = i * (target + overlap)
|
||||
end = start + target + 2 * overlap
|
||||
folded[i] = x[:, start:end, :]
|
||||
|
||||
return folded
|
||||
|
||||
@staticmethod
|
||||
def get_gru_cell(gru):
|
||||
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
|
||||
gru_cell.weight_hh.data = gru.weight_hh_l0.data
|
||||
gru_cell.weight_ih.data = gru.weight_ih_l0.data
|
||||
gru_cell.bias_hh.data = gru.bias_hh_l0.data
|
||||
gru_cell.bias_ih.data = gru.bias_ih_l0.data
|
||||
return gru_cell
|
||||
|
||||
@staticmethod
|
||||
def pad_tensor(x, pad, side="both"):
|
||||
# NB - this is just a quick method I need right now
|
||||
# i.e., it won't generalise to other shapes/dims
|
||||
b, t, c = x.size()
|
||||
total = t + 2 * pad if side == "both" else t + pad
|
||||
padded = torch.zeros(b, total, c).to(x.device)
|
||||
if side in ("before", "both"):
|
||||
padded[:, pad : pad + t, :] = x
|
||||
elif side == "after":
|
||||
padded[:, :t, :] = x
|
||||
return padded
|
||||
|
||||
@staticmethod
|
||||
def xfade_and_unfold(y, target, overlap):
|
||||
"""Applies a crossfade and unfolds into a 1d array.
|
||||
Args:
|
||||
y (ndarray) : Batched sequences of audio samples
|
||||
shape=(num_folds, target + 2 * overlap)
|
||||
dtype=np.float64
|
||||
overlap (int) : Timesteps for both xfade and rnn warmup
|
||||
Return:
|
||||
(ndarray) : audio samples in a 1d array
|
||||
shape=(total_len)
|
||||
dtype=np.float64
|
||||
Details:
|
||||
y = [[seq1],
|
||||
[seq2],
|
||||
[seq3]]
|
||||
Apply a gain envelope at both ends of the sequences
|
||||
y = [[seq1_in, seq1_target, seq1_out],
|
||||
[seq2_in, seq2_target, seq2_out],
|
||||
[seq3_in, seq3_target, seq3_out]]
|
||||
Stagger and add up the groups of samples:
|
||||
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
|
||||
"""
|
||||
|
||||
num_folds, length = y.shape
|
||||
target = length - 2 * overlap
|
||||
total_len = num_folds * (target + overlap) + overlap
|
||||
|
||||
# Need some silence for the rnn warmup
|
||||
silence_len = overlap // 2
|
||||
fade_len = overlap - silence_len
|
||||
silence = np.zeros((silence_len), dtype=np.float64)
|
||||
|
||||
# Equal power crossfade
|
||||
t = np.linspace(-1, 1, fade_len, dtype=np.float64)
|
||||
fade_in = np.sqrt(0.5 * (1 + t))
|
||||
fade_out = np.sqrt(0.5 * (1 - t))
|
||||
|
||||
# Concat the silence to the fades
|
||||
fade_in = np.concatenate([silence, fade_in])
|
||||
fade_out = np.concatenate([fade_out, silence])
|
||||
|
||||
# Apply the gain to the overlap samples
|
||||
y[:, :overlap] *= fade_in
|
||||
y[:, -overlap:] *= fade_out
|
||||
|
||||
unfolded = np.zeros((total_len), dtype=np.float64)
|
||||
|
||||
# Loop to add up all the samples
|
||||
for i in range(num_folds):
|
||||
start = i * (target + overlap)
|
||||
end = start + target + 2 * overlap
|
||||
unfolded[start:end] += y[i]
|
||||
|
||||
return unfolded
|
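# Why the envelopes above reconstruct unit gain across the crossfade:
#   fade_in(t)**2 + fade_out(t)**2 = 0.5 * (1 + t) + 0.5 * (1 - t) = 1,
# so summing the overlapped fold boundaries leaves signal power unchanged.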
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
mels = batch["input"]
|
||||
waveform = batch["waveform"]
|
||||
waveform_coarse = batch["waveform_coarse"]
|
||||
|
||||
y_hat = self.forward(waveform, mels)
|
||||
if isinstance(self.args.mode, int):
|
||||
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
waveform_coarse = waveform_coarse.float()
|
||||
waveform_coarse = waveform_coarse.unsqueeze(-1)
|
||||
# compute losses
|
||||
loss_dict = criterion(y_hat, waveform_coarse)
|
||||
return {"model_output": y_hat}, loss_dict
|
||||
|
||||
def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
@torch.no_grad()
|
||||
def test(
|
||||
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, Dict]:
|
||||
ap = self.ap
|
||||
figures = {}
|
||||
audios = {}
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for idx, sample in enumerate(samples):
|
||||
x = torch.FloatTensor(sample[0])
|
||||
x = x.to(next(self.parameters()).device)
|
||||
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
|
||||
x_hat = ap.melspectrogram(y_hat)
|
||||
figures.update(
|
||||
{
|
||||
f"test_{idx}/ground_truth": plot_spectrogram(x.T),
|
||||
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
|
||||
}
|
||||
)
|
||||
audios.update({f"test_{idx}/audio": y_hat})
|
||||
# audios.update({f"real_{idx}/audio": y_hat})
|
||||
return figures, audios
|
||||
|
||||
def test_log(
|
||||
self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
figures, audios = outputs
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: Dict) -> Dict:
|
||||
waveform = batch[0]
|
||||
mels = batch[1]
|
||||
waveform_coarse = batch[2]
|
||||
return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}
|
||||
|
||||
def get_data_loader( # pylint: disable=no-self-use
|
||||
self,
|
||||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: bool,
|
||||
samples: List,
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
):
|
||||
ap = self.ap
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=samples,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=config.model_args.pad,
|
||||
mode=config.model_args.mode,
|
||||
mulaw=config.model_args.mulaw,
|
||||
is_training=not is_eval,
|
||||
verbose=verbose,
|
||||
)
|
||||
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1 if is_eval else config.batch_size,
|
||||
shuffle=num_gpus == 0,
|
||||
collate_fn=dataset.collate,
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
return loader
|
||||
|
||||
def get_criterion(self):
|
||||
# define train functions
|
||||
return WaveRNNLoss(self.args.mode)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavernnConfig"):
|
||||
return Wavernn(config)
|
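# Hedged inference sketch for Wavernn. "config" is assumed to be a WavernnConfig;
# the target/overlap values mirror the config keys used in test() above
# (target_samples / overlap_samples) and are purely illustrative.
import torch

model = Wavernn(config)
mel = torch.randn(80, 100)  # (feat_dims, T) conditioning features
wav = model.inference(mel, batched=True, target=11000, overlap=550)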
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,154 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
|
||||
def gaussian_loss(y_hat, y, log_std_min=-7.0):
|
||||
assert y_hat.dim() == 3
|
||||
assert y_hat.size(2) == 2
|
||||
mean = y_hat[:, :, :1]
|
||||
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
|
||||
# TODO: replace with pytorch dist
|
||||
# note: this expression is the *negative* log-likelihood, so the mean returned
# below is directly usable as a loss despite the "log_probs" name
log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
|
||||
return log_probs.squeeze().mean()
|
||||

def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
    """Draw a sample from the Gaussian parameterized by `y_hat` and clip it to [-scale_factor, scale_factor]."""
    assert y_hat.size(2) == 2
    mean = y_hat[:, :, :1]
    log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
    dist = Normal(mean, torch.exp(log_std))
    sample = dist.sample()
    sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
    del dist
    return sample
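A quick sanity check for the Gaussian pair above (a sketch on random tensors, with the two functions in scope): y_hat packs a mean and a log-std per time step.

import torch

B, T = 4, 100
y_hat = torch.randn(B, T, 2)           # [:, :, 0] = mean, [:, :, 1] = log_std
y = torch.rand(B, T, 1) * 2.0 - 1.0    # fake waveform in [-1, 1]
loss = gaussian_loss(y_hat, y)         # scalar mean negative log-likelihood
sample = sample_from_gaussian(y_hat)   # (B, T, 1), clipped to [-1, 1]
print(loss.item(), sample.shape)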

def log_sum_exp(x):
    """Numerically stable log_sum_exp implementation that prevents overflow."""
    # TF ordering: reduce over the last axis
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
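The helper reduces over the last axis; a minimal check against the built-in torch.logsumexp (a sketch, assuming the function above is in scope):

import torch

x = torch.randn(3, 5, 10)
assert torch.allclose(log_sum_exp(x), torch.logsumexp(x, dim=-1), atol=1e-6)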
# Adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True):
    """Discretized mixture of logistic distributions loss. `y_hat` is (B, T, C) with
    C = 3 * num_mixtures and `y` is the target waveform in [-1, 1] with shape (B, T, 1)."""
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    y_hat = y_hat.permute(0, 2, 1)
    assert y_hat.dim() == 3
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix : 2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(F.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - F.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

    # TF equivalent:
    # log_probs = tf.where(x < -0.999, log_cdf_plus,
    #                      tf.where(x > 0.999, log_one_minus_cdf_min,
    #                               tf.where(cdf_delta > 1e-5,
    #                                        tf.log(tf.maximum(cdf_delta, 1e-12)),
    #                                        log_pdf_mid - np.log(127.5))))

    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
        log_pdf_mid - np.log((num_classes - 1) / 2)
    )
    inner_cond = (y > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.mean(log_sum_exp(log_probs))
    return -log_sum_exp(log_probs).unsqueeze(-1)

def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions.

    Args:
        y (Tensor): :math:`[B, C, T]`
        log_scale_min (float): Log scale minimum value.

    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample the mixture indicator from the softmax (Gumbel-max trick)
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(-torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = to_one_hot(argmax, nr_mix)
    # select logistic parameters
    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))

    x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)

    return x

def to_one_hot(tensor, n, fill_with=1.0):
    # we perform one-hot encoding with respect to the last axis
    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor)
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return one_hot
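Note the two layouts: the loss takes y_hat as (B, T, 3 * num_mixtures) and permutes internally, while the sampler expects (B, C, T). A minimal round trip on random tensors (a sketch, with the functions above in scope):

import torch

B, T, nr_mix = 2, 50, 10
y_hat = torch.randn(B, T, 3 * nr_mix)   # mixture logits, means, log scales
y = torch.rand(B, T, 1) * 2.0 - 1.0     # target waveform in [-1, 1]
loss = discretized_mix_logistic_loss(y_hat, y)                      # scalar, since reduce=True
wav = sample_from_discretized_mix_logistic(y_hat.transpose(1, 2))   # (B, T), values in [-1, 1]
print(loss.item(), wav.shape)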
@@ -0,0 +1,72 @@
from typing import Dict

import numpy as np
import torch
from matplotlib import pyplot as plt

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor


def interpolate_vocoder_input(scale_factor, spec):
    """Interpolate spectrogram by the scale factor.
    It is mainly used to match the sampling rates of
    the tts and vocoder models.

    Args:
        scale_factor (float): scale factor to interpolate the spectrogram
        spec (np.array): spectrogram to be interpolated

    Returns:
        torch.tensor: interpolated spectrogram.
    """
    print(" > before interpolation :", spec.shape)
    spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)  # pylint: disable=not-callable
    spec = torch.nn.functional.interpolate(
        spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
    ).squeeze(0)
    print(" > after interpolation :", spec.shape)
    return spec
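For instance, when a TTS model and a vocoder run at different sampling rates, the time axis of the mel spectrogram can be stretched by their ratio before vocoding (a usage sketch; the rates and shapes here are hypothetical):

import numpy as np

tts_sr, vocoder_sr = 22050, 24000
mel = np.random.rand(80, 200).astype(np.float32)   # (n_mels, frames)
scale_factor = [1.0, vocoder_sr / tts_sr]          # keep the mel axis, stretch the time axis
vocoder_input = interpolate_vocoder_input(scale_factor, mel)
print(vocoder_input.shape)                          # (1, 80, ~218)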

def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
    """Plot the predicted and the real waveform and their spectrograms.

    Args:
        y_hat (torch.tensor): Predicted waveform.
        y (torch.tensor): Real waveform.
        ap (AudioProcessor): Audio processor used to process the waveform.
        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.

    Returns:
        Dict: output figures keyed by the name of the figures.
    """
    if name_prefix is None:
        name_prefix = ""

    # select an instance from batch
    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
    y = y[0].squeeze().detach().cpu().numpy()

    spec_fake = ap.melspectrogram(y_hat).T
    spec_real = ap.melspectrogram(y).T
    spec_diff = np.abs(spec_fake - spec_real)

    # plot figure and save it
    fig_wave = plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(y)
    plt.title("groundtruth speech")
    plt.subplot(2, 1, 2)
    plt.plot(y_hat)
    plt.title("generated speech")
    plt.tight_layout()
    plt.close()

    figures = {
        name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
        name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
        name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
        name_prefix + "speech_comparison": fig_wave,
    }
    return figures
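A usage sketch (the AudioProcessor settings and tensor shapes are placeholders): the returned dict plugs directly into a dashboard logger's figure call.

import torch

from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(sample_rate=22050)   # all other settings left at their defaults
y = torch.randn(1, 1, 22050)             # one second of reference audio
y_hat = torch.randn(1, 1, 22050)         # one second of generated audio
figures = plot_results(y_hat, y, ap, name_prefix="eval/")
print(sorted(figures))                   # eval/spectrogram/..., eval/speech_comparison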