Add files via upload

This commit is contained in:
Sam Khoze
2024-06-18 13:21:08 -07:00
committed by GitHub
parent 6af40bc6cf
commit dc8b8bca5a
97 changed files with 10910 additions and 0 deletions
+154
View File
@@ -0,0 +1,154 @@
import importlib
import re
from coqpit import Coqpit
def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(config: Coqpit):
"""Load models directly from configuration."""
if "discriminator_model" in config and "generator_model" in config:
MyModel = importlib.import_module("TTS.vocoder.models.gan")
MyModel = getattr(MyModel, "GAN")
else:
MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
if config.model.lower() == "wavernn":
MyModel = getattr(MyModel, "Wavernn")
elif config.model.lower() == "gan":
MyModel = getattr(MyModel, "GAN")
elif config.model.lower() == "wavegrad":
MyModel = getattr(MyModel, "Wavegrad")
else:
try:
MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
print(" > Vocoder Model: {}".format(config.model))
return MyModel.init_from_config(config)
def setup_generator(c):
"""TODO: use config object as arguments"""
print(" > Generator Model: {}".format(c.generator_model))
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model))
# this is to preserve the Wavernn class name (instead of Wavernn)
if c.generator_model.lower() in "hifigan_generator":
model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
elif c.generator_model.lower() in "melgan_generator":
model = MyModel(
in_channels=c.audio["num_mels"],
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
elif c.generator_model in "melgan_fb_generator":
raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
elif c.generator_model.lower() in "multiband_melgan_generator":
model = MyModel(
in_channels=c.audio["num_mels"],
out_channels=4,
proj_kernel=7,
base_channels=384,
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
elif c.generator_model.lower() in "fullband_melgan_generator":
model = MyModel(
in_channels=c.audio["num_mels"],
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
elif c.generator_model.lower() in "parallel_wavegan_generator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=c.generator_model_params["num_res_blocks"],
stacks=c.generator_model_params["stacks"],
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=c.audio["num_mels"],
dropout=0.0,
bias=True,
use_weight_norm=True,
upsample_factors=c.generator_model_params["upsample_factors"],
)
elif c.generator_model.lower() in "univnet_generator":
model = MyModel(**c.generator_model_params)
else:
raise NotImplementedError(f"Model {c.generator_model} not implemented!")
return model
def setup_discriminator(c):
"""TODO: use config objekt as arguments"""
print(" > Discriminator Model: {}".format(c.discriminator_model))
if "parallel_wavegan" in c.discriminator_model:
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
else:
MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
if c.discriminator_model in "hifigan_discriminator":
model = MyModel()
if c.discriminator_model in "random_window_discriminator":
model = MyModel(
cond_channels=c.audio["num_mels"],
hop_length=c.audio["hop_length"],
uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
window_sizes=c.discriminator_model_params["window_sizes"],
)
if c.discriminator_model in "melgan_multiscale_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=c.discriminator_model_params["base_channels"],
max_channels=c.discriminator_model_params["max_channels"],
downsample_factors=c.discriminator_model_params["downsample_factors"],
)
if c.discriminator_model == "residual_parallel_wavegan_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=c.discriminator_model_params["num_layers"],
stacks=c.discriminator_model_params["stacks"],
res_channels=64,
gate_channels=128,
skip_channels=64,
dropout=0.0,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
)
if c.discriminator_model == "parallel_wavegan_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=c.discriminator_model_params["num_layers"],
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True,
)
if c.discriminator_model == "univnet_discriminator":
model = MyModel()
return model
+55
View File
@@ -0,0 +1,55 @@
from coqpit import Coqpit
from TTS.model import BaseTrainerModel
# pylint: skip-file
class BaseVocoder(BaseTrainerModel):
"""Base `vocoder` class. Every new `vocoder` model must inherit this.
It defines `vocoder` specific functions on top of `Model`.
Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as
- 3D tensors `batch x time x channels`
- 2D tensors `batch x channels`
- 1D tensors `batch x 1`
"""
MODEL_TYPE = "vocoder"
def __init__(self, config):
super().__init__()
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type.
If the config is for training with a name like "*Config", then the model args are embeded in the
config.model_args
If the config is for the model with a name like "*Args", then we assign the directly.
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
if "characters" in config:
_, self.config, num_chars = self.get_characters(config)
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
config.model_args.num_chars = num_chars
if "model_args" in config:
self.args = self.config.model_args
# This is for backward compatibility
if "model_params" in config:
self.args = self.config.model_params
else:
self.config = config
if "model_args" in config:
self.args = self.config.model_args
# This is for backward compatibility
if "model_params" in config:
self.args = self.config.model_params
else:
raise ValueError("config must be either a *Config or *Args")
@@ -0,0 +1,33 @@
import torch
from TTS.vocoder.models.melgan_generator import MelganGenerator
class FullbandMelganGenerator(MelganGenerator):
def __init__(
self,
in_channels=80,
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=(2, 8, 2, 2),
res_kernel=3,
num_res_blocks=4,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
proj_kernel=proj_kernel,
base_channels=base_channels,
upsample_factors=upsample_factors,
res_kernel=res_kernel,
num_res_blocks=num_res_blocks,
)
@torch.no_grad()
def inference(self, cond_features):
cond_features = cond_features.to(self.layers[1].weight.device)
cond_features = torch.nn.functional.pad(
cond_features, (self.inference_padding, self.inference_padding), "replicate"
)
return self.layers(cond_features)
+374
View File
@@ -0,0 +1,374 @@
from inspect import signature
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.models import setup_discriminator, setup_generator
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results
class GAN(BaseVocoder):
def __init__(self, config: Coqpit, ap: AudioProcessor = None):
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
It also helps mixing and matching different generator and disciminator networks easily.
To implement a new GAN models, you just need to define the generator and the discriminator networks, the rest
is handled by the `GAN` class.
Args:
config (Coqpit): Model configuration.
ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.
Examples:
Initializing the GAN model with HifiGAN generator and discriminator.
>>> from TTS.vocoder.configs import HifiganConfig
>>> config = HifiganConfig()
>>> model = GAN(config)
"""
super().__init__(config)
self.config = config
self.model_g = setup_generator(config)
self.model_d = setup_discriminator(config)
self.train_disc = False # if False, train only the generator.
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
self.ap = ap
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's forward pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.forward(x)
def inference(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's inference pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.inference(x)
def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
network on the current pass.
Args:
batch (Dict): Batch of samples returned by the dataloader.
criterion (Dict): Criterion used to compute the losses.
optimizer_idx (int): ID of the optimizer in use on the current pass.
Raises:
ValueError: `optimizer_idx` is an unexpected value.
Returns:
Tuple[Dict, Dict]: model outputs and the computed loss values.
"""
outputs = {}
loss_dict = {}
x = batch["input"]
y = batch["waveform"]
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if optimizer_idx == 0:
# DISCRIMINATOR optimization
# generator pass
y_hat = self.model_g(x)[:, :, : y.size(2)]
# cache for generator loss
# pylint: disable=W0201
self.y_hat_g = y_hat
self.y_hat_sub = None
self.y_sub_g = None
# PQMF formatting
if y_hat.shape[1] > 1:
self.y_hat_sub = y_hat
y_hat = self.model_g.pqmf_synthesis(y_hat)
self.y_hat_g = y_hat # save for generator loss
self.y_sub_g = self.model_g.pqmf_analysis(y)
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
# use different samples for G and D trainings
if self.config.diff_samples_for_G_and_D:
x_d = batch["input_disc"]
y_d = batch["waveform_disc"]
# use a different sample than generator
with torch.no_grad():
y_hat = self.model_g(x_d)
# PQMF formatting
if y_hat.shape[1] > 1:
y_hat = self.model_g.pqmf_synthesis(y_hat)
else:
# use the same samples as generator
x_d = x.clone()
y_d = y.clone()
y_hat = self.y_hat_g
# run D with or without cond. features
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
D_out_real = self.model_d(y_d, x_d)
else:
D_out_fake = self.model_d(y_hat.detach())
D_out_real = self.model_d(y_d)
# format D outputs
if isinstance(D_out_fake, tuple):
# self.model_d returns scores and features
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
scores_real, feats_real = None, None
else:
scores_real, feats_real = D_out_real
else:
# model D returns only scores
scores_fake = D_out_fake
scores_real = D_out_real
# compute losses
loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
outputs = {"model_outputs": y_hat}
if optimizer_idx == 1:
# GENERATOR loss
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(self.y_hat_g, x)
else:
D_out_fake = self.model_d(self.y_hat_g)
D_out_real = None
if self.config.use_feat_match_loss:
with torch.no_grad():
D_out_real = self.model_d(y)
# format D outputs
if isinstance(D_out_fake, tuple):
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
feats_real = None
else:
_, feats_real = D_out_real
else:
scores_fake = D_out_fake
feats_fake, feats_real = None, None
# compute losses
loss_dict = criterion[optimizer_idx](
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
)
outputs = {"model_outputs": self.y_hat_g}
return outputs, loss_dict
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
"""Logging shared by the training and evaluation.
Args:
name (str): Name of the run. `train` or `eval`,
ap (AudioProcessor): Audio processor used in training.
batch (Dict): Batch used in the last train/eval step.
outputs (Dict): Model outputs from the last train/eval step.
Returns:
Tuple[Dict, Dict]: log figures and audio samples.
"""
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
y = batch["waveform"]
figures = plot_results(y_hat, y, ap, name)
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
audios = {f"{name}/audio": sample_voice}
return figures, audios
def train_log(
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for training."""
figures, audios = self._log("eval", self.ap, batch, outputs)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for evaluation."""
figures, audios = self._log("eval", self.ap, batch, outputs)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self,
config: Coqpit,
checkpoint_path: str,
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
cache: bool = False,
) -> None:
"""Load a GAN checkpoint and initialize model parameters.
Args:
config (Coqpit): Model config.
checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
# band-aid for older than v0.0.15 GAN models
if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval)
else:
self.load_state_dict(state["model"])
if eval:
self.model_d = None
if hasattr(self.model_g, "remove_weight_norm"):
self.model_g.remove_weight_norm()
def on_train_step_start(self, trainer) -> None:
"""Enable the discriminator training based on `steps_to_start_discriminator`
Args:
trainer (Trainer): Trainer object.
"""
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
Returns:
List: optimizers.
"""
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
)
optimizer2 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
)
return [optimizer2, optimizer1]
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_disc, self.config.lr_gen]
def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Args:
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler2, scheduler1]
@staticmethod
def format_batch(batch: List) -> Dict:
"""Format the batch for training.
Args:
batch (List): Batch out of the dataloader.
Returns:
Dict: formatted model inputs.
"""
if isinstance(batch[0], list):
x_G, y_G = batch[0]
x_D, y_D = batch[1]
return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
x, y = batch
return {"input": x, "waveform": y}
def get_data_loader( # pylint: disable=no-self-use, unused-argument
self,
config: Coqpit,
assets: Dict,
is_eval: True,
samples: List,
verbose: bool,
num_gpus: int,
rank: int = None, # pylint: disable=unused-argument
):
"""Initiate and return the GAN dataloader.
Args:
config (Coqpit): Model config.
ap (AudioProcessor): Audio processor.
is_eval (True): Set the dataloader for evaluation if true.
samples (List): Data samples.
verbose (bool): Log information if true.
num_gpus (int): Number of GPUs in use.
rank (int): Rank of the current GPU. Defaults to None.
Returns:
DataLoader: Torch dataloader.
"""
dataset = GANDataset(
ap=self.ap,
items=samples,
seq_len=config.seq_len,
hop_len=self.ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
is_training=not is_eval,
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=1 if is_eval else config.batch_size,
shuffle=num_gpus == 0,
drop_last=False,
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def get_criterion(self):
"""Return criterions for the optimizers"""
return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
@staticmethod
def init_from_config(config: Coqpit, verbose=True) -> "GAN":
ap = AudioProcessor.init_from_config(config, verbose=verbose)
return GAN(config, ap=ap)
+217
View File
@@ -0,0 +1,217 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import functional as F
LRELU_SLOPE = 0.1
class DiscriminatorP(torch.nn.Module):
"""HiFiGAN Periodic Discriminator
Takes every Pth value from the input waveform and applied a stack of convoluations.
Note:
if `period` is 2
`waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat`
Args:
x (Tensor): input waveform.
Returns:
[Tensor]: discriminator scores per sample in the batch.
[List[Tensor]]: list of features from each convolutional layer.
Shapes:
x: [B, 1, T]
"""
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super().__init__()
self.period = period
get_padding = lambda k, d: int((k * d - d) / 2)
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
]
)
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
[Tensor]: discriminator scores per sample in the batch.
[List[Tensor]]: list of features from each convolutional layer.
Shapes:
x: [B, 1, T]
"""
feat = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
class MultiPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN Multi-Period Discriminator (MPD)
Wrapper for the `PeriodDiscriminator` to apply it in different periods.
Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
]
)
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
[List[Tensor]]: list of scores from each discriminator.
[List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
Shapes:
x: [B, 1, T]
"""
scores = []
feats = []
for _, d in enumerate(self.discriminators):
score, feat = d(x)
scores.append(score)
feats.append(feat)
return scores, feats
class DiscriminatorS(torch.nn.Module):
"""HiFiGAN Scale Discriminator.
It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper.
Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
Tensor: discriminator scores.
List[Tensor]: list of features from the convolutiona layers.
"""
feat = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat
class MultiScaleDiscriminator(torch.nn.Module):
"""HiFiGAN Multi-Scale Discriminator.
It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper.
"""
def __init__(self):
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
]
)
self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)])
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores = []
feats = []
for i, d in enumerate(self.discriminators):
if i != 0:
x = self.meanpools[i - 1](x)
score, feat = d(x)
scores.append(score)
feats.append(feat)
return scores, feats
class HifiganDiscriminator(nn.Module):
"""HiFiGAN discriminator wrapping MPD and MSD."""
def __init__(self):
super().__init__()
self.mpd = MultiPeriodDiscriminator()
self.msd = MultiScaleDiscriminator()
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
scores_, feats_ = self.msd(x)
return scores + scores_, feats + feats_
+301
View File
@@ -0,0 +1,301 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1
def get_padding(k, d):
return int((k * d - d) / 2)
class ResBlock1(torch.nn.Module):
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|--------------------------------------------------------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
]
)
def forward(self, x):
"""
Args:
x (Tensor): input tensor.
Returns:
Tensor: output tensor.
Shapes:
x: [B, C, T]
"""
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module):
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|---------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super().__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_parametrizations(l, "weight")
class HifiganGenerator(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
resblock_type,
resblock_dilation_sizes,
resblock_kernel_sizes,
upsample_kernel_sizes,
upsample_initial_channel,
upsample_factors,
inference_padding=5,
cond_channels=0,
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
Network:
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
.. -> zI ---|
resblockN_kNx1 -> zN ---'
Args:
in_channels (int): number of input tensor channels.
out_channels (int): number of output tensor channels.
resblock_type (str): type of the `ResBlock`. '1' or '2'.
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
for each consecutive upsampling layer.
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
"""
super().__init__()
self.inference_padding = inference_padding
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_factors)
# initial upsampling layers
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# MRF blocks
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm:
remove_parametrizations(self.conv_pre, "weight")
if not conv_post_weight_norm:
remove_parametrizations(self.conv_post, "weight")
def forward(self, x, g=None):
"""
Args:
x (Tensor): feature input tensor.
g (Tensor): global conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
o = self.conv_pre(x)
if hasattr(self, "cond_layer"):
o = o + self.cond_layer(g)
for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o)
z_sum = None
for j in range(self.num_kernels):
if z_sum is None:
z_sum = self.resblocks[i * self.num_kernels + j](o)
else:
z_sum += self.resblocks[i * self.num_kernels + j](o)
o = z_sum / self.num_kernels
o = F.leaky_relu(o)
o = self.conv_post(o)
o = torch.tanh(o)
return o
@torch.no_grad()
def inference(self, c):
"""
Args:
x (Tensor): conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
remove_parametrizations(self.conv_pre, "weight")
remove_parametrizations(self.conv_post, "weight")
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()
@@ -0,0 +1,84 @@
import numpy as np
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
class MelganDiscriminator(nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=16,
max_channels=1024,
downsample_factors=(4, 4, 4, 4),
groups_denominator=4,
):
super().__init__()
self.layers = nn.ModuleList()
layer_kernel_size = np.prod(kernel_sizes)
layer_padding = (layer_kernel_size - 1) // 2
# initial layer
self.layers += [
nn.Sequential(
nn.ReflectionPad1d(layer_padding),
weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)),
nn.LeakyReLU(0.2, inplace=True),
)
]
# downsampling layers
layer_in_channels = base_channels
for downsample_factor in downsample_factors:
layer_out_channels = min(layer_in_channels * downsample_factor, max_channels)
layer_kernel_size = downsample_factor * 10 + 1
layer_padding = (layer_kernel_size - 1) // 2
layer_groups = layer_in_channels // groups_denominator
self.layers += [
nn.Sequential(
weight_norm(
nn.Conv1d(
layer_in_channels,
layer_out_channels,
kernel_size=layer_kernel_size,
stride=downsample_factor,
padding=layer_padding,
groups=layer_groups,
)
),
nn.LeakyReLU(0.2, inplace=True),
)
]
layer_in_channels = layer_out_channels
# last 2 layers
layer_padding1 = (kernel_sizes[0] - 1) // 2
layer_padding2 = (kernel_sizes[1] - 1) // 2
self.layers += [
nn.Sequential(
weight_norm(
nn.Conv1d(
layer_out_channels,
layer_out_channels,
kernel_size=kernel_sizes[0],
stride=1,
padding=layer_padding1,
)
),
nn.LeakyReLU(0.2, inplace=True),
),
weight_norm(
nn.Conv1d(
layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2
)
),
]
def forward(self, x):
feats = []
for layer in self.layers:
x = layer(x)
feats.append(x)
return x, feats
+95
View File
@@ -0,0 +1,95 @@
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
class MelganGenerator(nn.Module):
def __init__(
self,
in_channels=80,
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=(8, 8, 2, 2),
res_kernel=3,
num_res_blocks=3,
):
super().__init__()
# assert model parameters
assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."
# setup additional model parameters
base_padding = (proj_kernel - 1) // 2
act_slope = 0.2
self.inference_padding = 2
# initial layer
layers = []
layers += [
nn.ReflectionPad1d(base_padding),
weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
]
# upsampling layers and residual stacks
for idx, upsample_factor in enumerate(upsample_factors):
layer_in_channels = base_channels // (2**idx)
layer_out_channels = base_channels // (2 ** (idx + 1))
layer_filter_size = upsample_factor * 2
layer_stride = upsample_factor
layer_output_padding = upsample_factor % 2
layer_padding = upsample_factor // 2 + layer_output_padding
layers += [
nn.LeakyReLU(act_slope),
weight_norm(
nn.ConvTranspose1d(
layer_in_channels,
layer_out_channels,
layer_filter_size,
stride=layer_stride,
padding=layer_padding,
output_padding=layer_output_padding,
bias=True,
)
),
ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
]
layers += [nn.LeakyReLU(act_slope)]
# final layer
layers += [
nn.ReflectionPad1d(base_padding),
weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
nn.Tanh(),
]
self.layers = nn.Sequential(*layers)
def forward(self, c):
return self.layers(c)
def inference(self, c):
c = c.to(self.layers[1].weight.device)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.layers(c)
def remove_weight_norm(self):
for _, layer in enumerate(self.layers):
if len(layer.state_dict()) != 0:
try:
nn.utils.parametrize.remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()
@@ -0,0 +1,50 @@
from torch import nn
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
class MelganMultiscaleDiscriminator(nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
num_scales=3,
kernel_sizes=(5, 3),
base_channels=16,
max_channels=1024,
downsample_factors=(4, 4, 4),
pooling_kernel_size=4,
pooling_stride=2,
pooling_padding=2,
groups_denominator=4,
):
super().__init__()
self.discriminators = nn.ModuleList(
[
MelganDiscriminator(
in_channels=in_channels,
out_channels=out_channels,
kernel_sizes=kernel_sizes,
base_channels=base_channels,
max_channels=max_channels,
downsample_factors=downsample_factors,
groups_denominator=groups_denominator,
)
for _ in range(num_scales)
]
)
self.pooling = nn.AvgPool1d(
kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False
)
def forward(self, x):
scores = []
feats = []
for disc in self.discriminators:
score, feat = disc(x)
scores.append(score)
feats.append(feat)
x = self.pooling(x)
return scores, feats
@@ -0,0 +1,41 @@
import torch
from TTS.vocoder.layers.pqmf import PQMF
from TTS.vocoder.models.melgan_generator import MelganGenerator
class MultibandMelganGenerator(MelganGenerator):
def __init__(
self,
in_channels=80,
out_channels=4,
proj_kernel=7,
base_channels=384,
upsample_factors=(2, 8, 2, 2),
res_kernel=3,
num_res_blocks=3,
):
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
proj_kernel=proj_kernel,
base_channels=base_channels,
upsample_factors=upsample_factors,
res_kernel=res_kernel,
num_res_blocks=num_res_blocks,
)
self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
def pqmf_analysis(self, x):
return self.pqmf_layer.analysis(x)
def pqmf_synthesis(self, x):
return self.pqmf_layer.synthesis(x)
@torch.no_grad()
def inference(self, cond_features):
cond_features = cond_features.to(self.layers[1].weight.device)
cond_features = torch.nn.functional.pad(
cond_features, (self.inference_padding, self.inference_padding), "replicate"
)
return self.pqmf_synthesis(self.layers(cond_features))
@@ -0,0 +1,187 @@
import math
import torch
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
class ParallelWaveganDiscriminator(nn.Module):
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
It classifies each audio window real/fake and returns a sequence
of predictions.
It is a stack of convolutional blocks with dilation.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=10,
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True,
):
super().__init__()
assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
assert dilation_factor > 0, " [!] dilation factor must be > 0."
self.conv_layers = nn.ModuleList()
conv_in_channels = in_channels
for i in range(num_layers - 1):
if i == 0:
dilation = 1
else:
dilation = i if dilation_factor == 1 else dilation_factor**i
conv_in_channels = conv_channels
padding = (kernel_size - 1) // 2 * dilation
conv_layer = [
nn.Conv1d(
conv_in_channels,
conv_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
bias=bias,
),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
]
self.conv_layers += conv_layer
padding = (kernel_size - 1) // 2
last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
self.conv_layers += [last_conv_layer]
self.apply_weight_norm()
def forward(self, x):
"""
x : (B, 1, T).
Returns:
Tensor: (B, 1, T)
"""
for f in self.conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
class ResidualParallelWaveganDiscriminator(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
dropout=0.0,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
):
super().__init__()
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
self.in_channels = in_channels
self.out_channels = out_channels
self.num_layers = num_layers
self.stacks = stacks
self.kernel_size = kernel_size
self.res_factor = math.sqrt(1.0 / num_layers)
# check the number of num_layers and stacks
assert num_layers % stacks == 0
layers_per_stack = num_layers // stacks
# define first convolution
self.first_conv = nn.Sequential(
nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
)
# define residual blocks
self.conv_layers = nn.ModuleList()
for layer in range(num_layers):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
res_channels=res_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=-1,
dilation=dilation,
dropout=dropout,
bias=bias,
use_causal_conv=False,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = nn.ModuleList(
[
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True),
]
)
# apply weight norm
self.apply_weight_norm()
def forward(self, x):
"""
x: (B, 1, T).
"""
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, None)
skips += h
skips *= self.res_factor
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
print(f"Weight norm is removed from {m}.")
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
@@ -0,0 +1,164 @@
import math
import numpy as np
import torch
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
It is similar to WaveNet with no causal convolution.
It is conditioned on an aux feature (spectrogram) to generate
an output waveform from an input noise.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
dropout=0.0,
bias=True,
use_weight_norm=True,
upsample_factors=[4, 4, 4, 4],
inference_padding=2,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.num_res_blocks = num_res_blocks
self.stacks = stacks
self.kernel_size = kernel_size
self.upsample_factors = upsample_factors
self.upsample_scale = np.prod(upsample_factors)
self.inference_padding = inference_padding
self.use_weight_norm = use_weight_norm
# check the number of layers and stacks
assert num_res_blocks % stacks == 0
layers_per_stack = num_res_blocks // stacks
# define first convolution
self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True)
# define conv + upsampling network
self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(num_res_blocks):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
res_channels=res_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = torch.nn.ModuleList(
[
torch.nn.ReLU(inplace=True),
torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True),
torch.nn.ReLU(inplace=True),
torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True),
]
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, c):
"""
c: (B, C ,T').
o: Output tensor (B, out_channels, T)
"""
# random noise
x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
x = x.to(self.first_conv.bias.device)
# perform upsampling
if c is not None and self.upsample_net is not None:
c = self.upsample_net(c)
assert (
c.shape[-1] == x.shape[-1]
), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
# encode to hidden representation
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, c)
skips += h
skips *= math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
@torch.no_grad()
def inference(self, c):
c = c.to(self.first_conv.weight.device)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
if self.use_weight_norm:
self.remove_weight_norm()
@@ -0,0 +1,203 @@
import numpy as np
from torch import nn
class GBlock(nn.Module):
def __init__(self, in_channels, cond_channels, downsample_factor):
super().__init__()
self.in_channels = in_channels
self.cond_channels = cond_channels
self.downsample_factor = downsample_factor
self.start = nn.Sequential(
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
nn.ReLU(),
nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1),
)
self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1)
self.end = nn.Sequential(
nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2)
)
self.residual = nn.Sequential(
nn.Conv1d(in_channels, in_channels * 2, kernel_size=1),
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
)
def forward(self, inputs, conditions):
outputs = self.start(inputs) + self.lc_conv1d(conditions)
outputs = self.end(outputs)
residual_outputs = self.residual(inputs)
outputs = outputs + residual_outputs
return outputs
class DBlock(nn.Module):
def __init__(self, in_channels, out_channels, downsample_factor):
super().__init__()
self.in_channels = in_channels
self.downsample_factor = downsample_factor
self.out_channels = out_channels
self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor)
self.layers = nn.Sequential(
nn.ReLU(),
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
)
self.residual = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=1),
)
def forward(self, inputs):
if self.downsample_factor > 1:
outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs))
else:
outputs = self.layers(inputs) + self.residual(inputs)
return outputs
class ConditionalDiscriminator(nn.Module):
def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)):
super().__init__()
assert len(downsample_factors) == len(out_channels) + 1
self.in_channels = in_channels
self.cond_channels = cond_channels
self.downsample_factors = downsample_factors
self.out_channels = out_channels
self.pre_cond_layers = nn.ModuleList()
self.post_cond_layers = nn.ModuleList()
# layers before condition features
self.pre_cond_layers += [DBlock(in_channels, 64, 1)]
in_channels = 64
for i, channel in enumerate(out_channels):
self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i]))
in_channels = channel
# condition block
self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1])
# layers after condition block
self.post_cond_layers += [
DBlock(in_channels * 2, in_channels * 2, 1),
DBlock(in_channels * 2, in_channels * 2, 1),
nn.AdaptiveAvgPool1d(1),
nn.Conv1d(in_channels * 2, 1, kernel_size=1),
]
def forward(self, inputs, conditions):
batch_size = inputs.size()[0]
outputs = inputs.view(batch_size, self.in_channels, -1)
for layer in self.pre_cond_layers:
outputs = layer(outputs)
outputs = self.cond_block(outputs, conditions)
for layer in self.post_cond_layers:
outputs = layer(outputs)
return outputs
class UnconditionalDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)):
super().__init__()
self.downsample_factors = downsample_factors
self.in_channels = in_channels
self.downsample_factors = downsample_factors
self.out_channels = out_channels
self.layers = nn.ModuleList()
self.layers += [DBlock(self.in_channels, base_channels, 1)]
in_channels = base_channels
for i, factor in enumerate(downsample_factors):
self.layers.append(DBlock(in_channels, out_channels[i], factor))
in_channels *= 2
self.layers += [
DBlock(in_channels, in_channels, 1),
DBlock(in_channels, in_channels, 1),
nn.AdaptiveAvgPool1d(1),
nn.Conv1d(in_channels, 1, kernel_size=1),
]
def forward(self, inputs):
batch_size = inputs.size()[0]
outputs = inputs.view(batch_size, self.in_channels, -1)
for layer in self.layers:
outputs = layer(outputs)
return outputs
class RandomWindowDiscriminator(nn.Module):
"""Random Window Discriminator as described in
http://arxiv.org/abs/1909.11646"""
def __init__(
self,
cond_channels,
hop_length,
uncond_disc_donwsample_factors=(8, 4),
cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)),
cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)),
window_sizes=(512, 1024, 2048, 4096, 8192),
):
super().__init__()
self.cond_channels = cond_channels
self.window_sizes = window_sizes
self.hop_length = hop_length
self.base_window_size = self.hop_length * 2
self.ks = [ws // self.base_window_size for ws in window_sizes]
# check arguments
assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes)
for ws in window_sizes:
assert ws % hop_length == 0
for idx, cf in enumerate(cond_disc_downsample_factors):
assert np.prod(cf) == hop_length // self.ks[idx]
# define layers
self.unconditional_discriminators = nn.ModuleList([])
for k in self.ks:
layer = UnconditionalDiscriminator(
in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors
)
self.unconditional_discriminators.append(layer)
self.conditional_discriminators = nn.ModuleList([])
for idx, k in enumerate(self.ks):
layer = ConditionalDiscriminator(
in_channels=k,
cond_channels=cond_channels,
downsample_factors=cond_disc_downsample_factors[idx],
out_channels=cond_disc_out_channels[idx],
)
self.conditional_discriminators.append(layer)
def forward(self, x, c):
scores = []
feats = []
# unconditional pass
for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators):
index = np.random.randint(x.shape[-1] - window_size)
score = layer(x[:, :, index : index + window_size])
scores.append(score)
# conditional pass
for window_size, layer in zip(self.window_sizes, self.conditional_discriminators):
frame_size = window_size // self.hop_length
lc_index = np.random.randint(c.shape[-1] - frame_size)
sample_index = lc_index * self.hop_length
x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length]
c_sub = c[:, :, lc_index : lc_index + frame_size]
score = layer(x_sub, c_sub)
scores.append(score)
return scores, feats
@@ -0,0 +1,95 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
LRELU_SLOPE = 0.1
class SpecDiscriminator(nn.Module):
"""docstring for Discriminator."""
def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False):
super().__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.fft_size = fft_size
self.hop_length = hop_length
self.win_length = win_length
self.stft = TorchSTFT(fft_size, hop_length, win_length)
self.discriminators = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
]
)
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
def forward(self, y):
fmap = []
with torch.no_grad():
y = y.squeeze(1)
y = self.stft(y)
y = y.unsqueeze(1)
for _, d in enumerate(self.discriminators):
y = d(y)
y = F.leaky_relu(y, LRELU_SLOPE)
fmap.append(y)
y = self.out(y)
fmap.append(y)
return torch.flatten(y, 1, -1), fmap
class MultiResSpecDiscriminator(torch.nn.Module):
def __init__( # pylint: disable=dangerous-default-value
self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window"
):
super().__init__()
self.discriminators = nn.ModuleList(
[
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window),
]
)
def forward(self, x):
scores = []
feats = []
for d in self.discriminators:
score, feat = d(x)
scores.append(score)
feats.append(feat)
return scores, feats
class UnivnetDiscriminator(nn.Module):
"""Univnet discriminator wrapping MPD and MSD."""
def __init__(self):
super().__init__()
self.mpd = MultiPeriodDiscriminator()
self.msd = MultiResSpecDiscriminator()
def forward(self, x):
"""
Args:
x (Tensor): input waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
scores_, feats_ = self.msd(x)
return scores + scores_, feats + feats_
+157
View File
@@ -0,0 +1,157 @@
from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock
LRELU_SLOPE = 0.1
class UnivnetGenerator(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
cond_channels: int,
upsample_factors: List[int],
lvc_layers_each_block: int,
lvc_kernel_size: int,
kpnet_hidden_channels: int,
kpnet_conv_size: int,
dropout: float,
use_weight_norm=True,
):
"""Univnet Generator network.
Paper: https://arxiv.org/pdf/2106.07889.pdf
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of channels of the output tensor.
hidden_channels (int): Number of hidden network channels.
cond_channels (int): Number of channels of the conditioning tensors.
upsample_factors (List[int]): List of uplsample factors for the upsampling layers.
lvc_layers_each_block (int): Number of LVC layers in each block.
lvc_kernel_size (int): Kernel size of the LVC layers.
kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
kpnet_conv_size (int): Number of convolution channels in the key-point network.
dropout (float): Dropout rate.
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.cond_channels = cond_channels
self.upsample_scale = np.prod(upsample_factors)
self.lvc_block_nums = len(upsample_factors)
# define first convolution
self.first_conv = torch.nn.Conv1d(
in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
)
# define residual blocks
self.lvc_blocks = torch.nn.ModuleList()
cond_hop_length = 1
for n in range(self.lvc_block_nums):
cond_hop_length = cond_hop_length * upsample_factors[n]
lvcb = LVCBlock(
in_channels=hidden_channels,
cond_channels=cond_channels,
upsample_ratio=upsample_factors[n],
conv_layers=lvc_layers_each_block,
conv_kernel_size=lvc_kernel_size,
cond_hop_length=cond_hop_length,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=dropout,
)
self.lvc_blocks += [lvcb]
# define output layers
self.last_conv_layers = torch.nn.ModuleList(
[
torch.nn.Conv1d(
hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
),
]
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, c):
"""Calculate forward propagation.
Args:
c (Tensor): Local conditioning auxiliary features (B, C ,T').
Returns:
Tensor: Output tensor (B, out_channels, T)
"""
# random noise
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
x = x.to(self.first_conv.bias.device)
x = self.first_conv(x)
for n in range(self.lvc_block_nums):
x = self.lvc_blocks[n](x, c)
# apply final layers
for f in self.last_conv_layers:
x = F.leaky_relu(x, LRELU_SLOPE)
x = f(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
parametrize.remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
"""Return receptive field size."""
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
@torch.no_grad()
def inference(self, c):
"""Perform inference.
Args:
c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
Returns:
Tensor: Output tensor (T, out_channels)
"""
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
x = x.to(self.first_conv.bias.device)
c = c.to(next(self.parameters()))
return self.forward(c)
+345
View File
@@ -0,0 +1,345 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results
@dataclass
class WavegradArgs(Coqpit):
in_channels: int = 80
out_channels: int = 1
use_weight_norm: bool = False
y_conv_channels: int = 32
x_conv_channels: int = 768
dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
upsample_dilations: List[List[int]] = field(
default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
)
class Wavegrad(BaseVocoder):
"""🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713
Examples:
Initializing the model.
>>> from TTS.vocoder.configs import WavegradConfig
>>> config = WavegradConfig()
>>> model = Wavegrad(config)
Paper Abstract:
This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
Audio samples are available at this https URL.
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__(config)
self.config = config
self.use_weight_norm = config.model_params.use_weight_norm
self.hop_len = np.prod(config.model_params.upsample_factors)
self.noise_level = None
self.num_steps = None
self.beta = None
self.alpha = None
self.alpha_hat = None
self.c1 = None
self.c2 = None
self.sigma = None
# dblocks
self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
self.dblocks = nn.ModuleList([])
ic = config.model_params.y_conv_channels
for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
self.dblocks.append(DBlock(ic, oc, df))
ic = oc
# film
self.film = nn.ModuleList([])
ic = config.model_params.y_conv_channels
for oc in reversed(config.model_params.ublock_out_channels):
self.film.append(FiLM(ic, oc))
ic = oc
# ublocksn
self.ublocks = nn.ModuleList([])
ic = config.model_params.x_conv_channels
for oc, uf, ud in zip(
config.model_params.ublock_out_channels,
config.model_params.upsample_factors,
config.model_params.upsample_dilations,
):
self.ublocks.append(UBlock(ic, oc, uf, ud))
ic = oc
self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)
if config.model_params.use_weight_norm:
self.apply_weight_norm()
def forward(self, x, spectrogram, noise_scale):
shift_and_scale = []
x = self.y_conv(x)
shift_and_scale.append(self.film[0](x, noise_scale))
for film, layer in zip(self.film[1:], self.dblocks):
x = layer(x)
shift_and_scale.append(film(x, noise_scale))
x = self.x_conv(spectrogram)
for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)):
x = layer(x, film_shift, film_scale)
x = self.out_conv(x)
return x
def load_noise_schedule(self, path):
beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg
self.compute_noise_level(beta)
@torch.no_grad()
def inference(self, x, y_n=None):
"""
Shapes:
x: :math:`[B, C , T]`
y_n: :math:`[B, 1, T]`
"""
if y_n is None:
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
else:
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
y_n = y_n.type_as(x)
sqrt_alpha_hat = self.noise_level.to(x)
for n in range(len(self.alpha) - 1, -1, -1):
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
if n > 0:
z = torch.randn_like(y_n)
y_n += self.sigma[n - 1] * z
y_n.clamp_(-1.0, 1.0)
return y_n
def compute_y_n(self, y_0):
"""Compute noisy audio based on noise schedule"""
self.noise_level = self.noise_level.to(y_0)
if len(y_0.shape) == 3:
y_0 = y_0.squeeze(1)
s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
l_a, l_b = self.noise_level[s], self.noise_level[s + 1]
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
noise_scale = noise_scale.unsqueeze(1)
noise = torch.randn_like(y_0)
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
def compute_noise_level(self, beta):
"""Compute noise schedule parameters"""
self.num_steps = len(beta)
alpha = 1 - beta
alpha_hat = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0)
noise_level = alpha_hat**0.5
# pylint: disable=not-callable
self.beta = torch.tensor(beta.astype(np.float32))
self.alpha = torch.tensor(alpha.astype(np.float32))
self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
self.noise_level = torch.tensor(noise_level.astype(np.float32))
self.c1 = 1 / self.alpha**0.5
self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5
def remove_weight_norm(self):
for _, layer in enumerate(self.dblocks):
if len(layer.state_dict()) != 0:
try:
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
for _, layer in enumerate(self.film):
if len(layer.state_dict()) != 0:
try:
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
for _, layer in enumerate(self.ublocks):
if len(layer.state_dict()) != 0:
try:
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
remove_parametrizations(self.x_conv, "weight")
remove_parametrizations(self.out_conv, "weight")
remove_parametrizations(self.y_conv, "weight")
def apply_weight_norm(self):
for _, layer in enumerate(self.dblocks):
if len(layer.state_dict()) != 0:
layer.apply_weight_norm()
for _, layer in enumerate(self.film):
if len(layer.state_dict()) != 0:
layer.apply_weight_norm()
for _, layer in enumerate(self.ublocks):
if len(layer.state_dict()) != 0:
layer.apply_weight_norm()
self.x_conv = weight_norm(self.x_conv)
self.out_conv = weight_norm(self.out_conv)
self.y_conv = weight_norm(self.y_conv)
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
if self.config.model_params.use_weight_norm:
self.remove_weight_norm()
betas = np.linspace(
config["test_noise_schedule"]["min_val"],
config["test_noise_schedule"]["max_val"],
config["test_noise_schedule"]["num_steps"],
)
self.compute_noise_level(betas)
else:
betas = np.linspace(
config["train_noise_schedule"]["min_val"],
config["train_noise_schedule"]["max_val"],
config["train_noise_schedule"]["num_steps"],
)
self.compute_noise_level(betas)
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
# format data
x = batch["input"]
y = batch["waveform"]
# set noise scale
noise, x_noisy, noise_scale = self.compute_y_n(y)
# forward pass
noise_hat = self.forward(x_noisy, x, noise_scale)
# compute losses
loss = criterion(noise, noise_hat)
return {"model_output": noise_hat}, {"loss": loss}
def train_log( # pylint: disable=no-self-use
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
pass
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
def eval_log( # pylint: disable=no-self-use
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> None:
pass
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
# setup noise schedule and inference
ap = assets["audio_processor"]
noise_schedule = self.config["test_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
samples = test_loader.dataset.load_test_samples(1)
for sample in samples:
x = sample[0]
x = x[None, :, :].to(next(self.parameters()).device)
y = sample[1]
y = y[None, :]
# compute voice
y_pred = self.inference(x)
# compute spectrograms
figures = plot_results(y_pred, y, ap, "test")
# Sample audio
sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
return figures, {"test/audio": sample_voice}
def get_optimizer(self):
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
def get_scheduler(self, optimizer):
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
@staticmethod
def get_criterion():
return torch.nn.L1Loss()
@staticmethod
def format_batch(batch: Dict) -> Dict:
# return a whole audio segment
m, y = batch[0], batch[1]
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
ap = assets["audio_processor"]
dataset = WaveGradDataset(
ap=ap,
items=samples,
seq_len=self.config.seq_len,
hop_len=ap.hop_length,
pad_short=self.config.pad_short,
conv_pad=self.config.conv_pad,
is_training=not is_eval,
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=self.config.batch_size,
shuffle=num_gpus <= 1,
drop_last=False,
sampler=sampler,
num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
pin_memory=False,
)
return loader
def on_epoch_start(self, trainer): # pylint: disable=unused-argument
noise_schedule = self.config["train_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
@staticmethod
def init_from_config(config: "WavegradConfig"):
return Wavegrad(config)
+646
View File
@@ -0,0 +1,646 @@
import sys
import time
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
def stream(string, variables):
sys.stdout.write(f"\r{string}" % variables)
# pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305
class ResBlock(nn.Module):
def __init__(self, dims):
super().__init__()
self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
self.batch_norm1 = nn.BatchNorm1d(dims)
self.batch_norm2 = nn.BatchNorm1d(dims)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.batch_norm1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.batch_norm2(x)
return x + residual
class MelResNet(nn.Module):
def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
super().__init__()
k_size = pad * 2 + 1
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
self.batch_norm = nn.BatchNorm1d(compute_dims)
self.layers = nn.ModuleList()
for _ in range(num_res_blocks):
self.layers.append(ResBlock(compute_dims))
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
def forward(self, x):
x = self.conv_in(x)
x = self.batch_norm(x)
x = F.relu(x)
for f in self.layers:
x = f(x)
x = self.conv_out(x)
return x
class Stretch2d(nn.Module):
def __init__(self, x_scale, y_scale):
super().__init__()
self.x_scale = x_scale
self.y_scale = y_scale
def forward(self, x):
b, c, h, w = x.size()
x = x.unsqueeze(-1).unsqueeze(3)
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
return x.view(b, c, h * self.y_scale, w * self.x_scale)
class UpsampleNetwork(nn.Module):
def __init__(
self,
feat_dims,
upsample_scales,
compute_dims,
num_res_blocks,
res_out_dims,
pad,
use_aux_net,
):
super().__init__()
self.total_scale = np.cumproduct(upsample_scales)[-1]
self.indent = pad * self.total_scale
self.use_aux_net = use_aux_net
if use_aux_net:
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
self.resnet_stretch = Stretch2d(self.total_scale, 1)
self.up_layers = nn.ModuleList()
for scale in upsample_scales:
k_size = (1, scale * 2 + 1)
padding = (0, scale)
stretch = Stretch2d(scale, 1)
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
conv.weight.data.fill_(1.0 / k_size[1])
self.up_layers.append(stretch)
self.up_layers.append(conv)
def forward(self, m):
if self.use_aux_net:
aux = self.resnet(m).unsqueeze(1)
aux = self.resnet_stretch(aux)
aux = aux.squeeze(1)
aux = aux.transpose(1, 2)
else:
aux = None
m = m.unsqueeze(1)
for f in self.up_layers:
m = f(m)
m = m.squeeze(1)[:, :, self.indent : -self.indent]
return m.transpose(1, 2), aux
class Upsample(nn.Module):
def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net):
super().__init__()
self.scale = scale
self.pad = pad
self.indent = pad * scale
self.use_aux_net = use_aux_net
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
def forward(self, m):
if self.use_aux_net:
aux = self.resnet(m)
aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True)
aux = aux.transpose(1, 2)
else:
aux = None
m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True)
m = m[:, :, self.indent : -self.indent]
m = m * 0.045 # empirically found
return m.transpose(1, 2), aux
@dataclass
class WavernnArgs(Coqpit):
"""🐸 WaveRNN model arguments.
rnn_dims (int):
Number of hidden channels in RNN layers. Defaults to 512.
fc_dims (int):
Number of hidden channels in fully-conntected layers. Defaults to 512.
compute_dims (int):
Number of hidden channels in the feature ResNet. Defaults to 128.
res_out_dim (int):
Number of hidden channels in the feature ResNet output. Defaults to 128.
num_res_blocks (int):
Number of residual blocks in the ResNet. Defaults to 10.
use_aux_net (bool):
enable/disable the feature ResNet. Defaults to True.
use_upsample_net (bool):
enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True.
upsample_factors (list):
Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```.
mode (str):
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
Gaussian Distribution and `bits` for quantized bits as the model's output.
mulaw (bool):
enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
to `True`.
pad (int):
Padding applied to the input feature frames against the convolution layers of the feature network.
Defaults to 2.
"""
rnn_dims: int = 512
fc_dims: int = 512
compute_dims: int = 128
res_out_dims: int = 128
num_res_blocks: int = 10
use_aux_net: bool = True
use_upsample_net: bool = True
upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
mode: str = "mold" # mold [string], gauss [string], bits [int]
mulaw: bool = True # apply mulaw if mode is bits
pad: int = 2
feat_dims: int = 80
class Wavernn(BaseVocoder):
def __init__(self, config: Coqpit):
"""🐸 WaveRNN model.
Original paper - https://arxiv.org/abs/1802.08435
Official implementation - https://github.com/fatchord/WaveRNN
Args:
config (Coqpit): [description]
Raises:
RuntimeError: [description]
Examples:
>>> from TTS.vocoder.configs import WavernnConfig
>>> config = WavernnConfig()
>>> model = Wavernn(config)
Paper Abstract:
Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
both estimating the data distribution and generating high-quality samples. Efficient sampling for this
class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
describe a set of general techniques for reducing sampling time while maintaining high output quality.
We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
parameters, large sparse networks perform better than small dense networks and this relationship holds for
sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
orthogonal method for increasing sampling efficiency.
"""
super().__init__(config)
if isinstance(self.args.mode, int):
self.n_classes = 2**self.args.mode
elif self.args.mode == "mold":
self.n_classes = 3 * 10
elif self.args.mode == "gauss":
self.n_classes = 2
else:
raise RuntimeError("Unknown model mode value - ", self.args.mode)
self.ap = AudioProcessor(**config.audio.to_dict())
self.aux_dims = self.args.res_out_dims // 4
if self.args.use_upsample_net:
assert (
np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length
), " [!] upsample scales needs to be equal to hop_length"
self.upsample = UpsampleNetwork(
self.args.feat_dims,
self.args.upsample_factors,
self.args.compute_dims,
self.args.num_res_blocks,
self.args.res_out_dims,
self.args.pad,
self.args.use_aux_net,
)
else:
self.upsample = Upsample(
config.audio.hop_length,
self.args.pad,
self.args.num_res_blocks,
self.args.feat_dims,
self.args.compute_dims,
self.args.res_out_dims,
self.args.use_aux_net,
)
if self.args.use_aux_net:
self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
else:
self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
def forward(self, x, mels):
bsize = x.size(0)
h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
mels, aux = self.upsample(mels)
if self.args.use_aux_net:
aux_idx = [self.aux_dims * i for i in range(5)]
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
x = (
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
if self.args.use_aux_net
else torch.cat([x.unsqueeze(-1), mels], dim=2)
)
x = self.I(x)
res = x
self.rnn1.flatten_parameters()
x, _ = self.rnn1(x, h1)
x = x + res
res = x
x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
self.rnn2.flatten_parameters()
x, _ = self.rnn2(x, h2)
x = x + res
x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
x = F.relu(self.fc1(x))
x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
x = F.relu(self.fc2(x))
return self.fc3(x)
def inference(self, mels, batched=None, target=None, overlap=None):
self.eval()
output = []
start = time.time()
rnn1 = self.get_gru_cell(self.rnn1)
rnn2 = self.get_gru_cell(self.rnn2)
with torch.no_grad():
if isinstance(mels, np.ndarray):
mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
if mels.ndim == 2:
mels = mels.unsqueeze(0)
wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
mels, aux = self.upsample(mels.transpose(1, 2))
if batched:
mels = self.fold_with_overlap(mels, target, overlap)
if aux is not None:
aux = self.fold_with_overlap(aux, target, overlap)
b_size, seq_len, _ = mels.size()
h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
x = torch.zeros(b_size, 1).type_as(mels)
if self.args.use_aux_net:
d = self.aux_dims
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
for i in range(seq_len):
m_t = mels[:, i, :]
if self.args.use_aux_net:
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
x = self.I(x)
h1 = rnn1(x, h1)
x = x + h1
inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
h2 = rnn2(inp, h2)
x = x + h2
x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
x = F.relu(self.fc1(x))
x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
x = F.relu(self.fc2(x))
logits = self.fc3(x)
if self.args.mode == "mold":
sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
x = sample.transpose(0, 1).type_as(mels)
elif self.args.mode == "gauss":
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
x = sample.transpose(0, 1).type_as(mels)
elif isinstance(self.args.mode, int):
posterior = F.softmax(logits, dim=1)
distrib = torch.distributions.Categorical(posterior)
sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
output.append(sample)
x = sample.unsqueeze(-1)
else:
raise RuntimeError("Unknown model mode value - ", self.args.mode)
if i % 100 == 0:
self.gen_display(i, seq_len, b_size, start)
output = torch.stack(output).transpose(0, 1)
output = output.cpu()
if batched:
output = output.numpy()
output = output.astype(np.float64)
output = self.xfade_and_unfold(output, target, overlap)
else:
output = output[0]
if self.args.mulaw and isinstance(self.args.mode, int):
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
output = output[:wave_len]
if wave_len > len(fade_out):
output[-20 * self.config.audio.hop_length :] *= fade_out
self.train()
return output
def gen_display(self, i, seq_len, b_size, start):
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
stream(
"%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
(i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
)
def fold_with_overlap(self, x, target, overlap):
"""Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold()
Args:
x (tensor) : Upsampled conditioning features.
shape=(1, timesteps, features)
target (int) : Target timesteps for each index of batch
overlap (int) : Timesteps for both xfade and rnn warmup
Return:
(tensor) : shape=(num_folds, target + 2 * overlap, features)
Details:
x = [[h1, h2, ... hn]]
Where each h is a vector of conditioning features
Eg: target=2, overlap=1 with x.size(1)=10
folded = [[h1, h2, h3, h4],
[h4, h5, h6, h7],
[h7, h8, h9, h10]]
"""
_, total_len, features = x.size()
# Calculate variables needed
num_folds = (total_len - overlap) // (target + overlap)
extended_len = num_folds * (overlap + target) + overlap
remaining = total_len - extended_len
# Pad if some time steps poking out
if remaining != 0:
num_folds += 1
padding = target + 2 * overlap - remaining
x = self.pad_tensor(x, padding, side="after")
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
# Get the values for the folded tensor
for i in range(num_folds):
start = i * (target + overlap)
end = start + target + 2 * overlap
folded[i] = x[:, start:end, :]
return folded
@staticmethod
def get_gru_cell(gru):
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
gru_cell.weight_hh.data = gru.weight_hh_l0.data
gru_cell.weight_ih.data = gru.weight_ih_l0.data
gru_cell.bias_hh.data = gru.bias_hh_l0.data
gru_cell.bias_ih.data = gru.bias_ih_l0.data
return gru_cell
@staticmethod
def pad_tensor(x, pad, side="both"):
# NB - this is just a quick method i need right now
# i.e., it won't generalise to other shapes/dims
b, t, c = x.size()
total = t + 2 * pad if side == "both" else t + pad
padded = torch.zeros(b, total, c).to(x.device)
if side in ("before", "both"):
padded[:, pad : pad + t, :] = x
elif side == "after":
padded[:, :t, :] = x
return padded
@staticmethod
def xfade_and_unfold(y, target, overlap):
"""Applies a crossfade and unfolds into a 1d array.
Args:
y (ndarry) : Batched sequences of audio samples
shape=(num_folds, target + 2 * overlap)
dtype=np.float64
overlap (int) : Timesteps for both xfade and rnn warmup
Return:
(ndarry) : audio samples in a 1d array
shape=(total_len)
dtype=np.float64
Details:
y = [[seq1],
[seq2],
[seq3]]
Apply a gain envelope at both ends of the sequences
y = [[seq1_in, seq1_target, seq1_out],
[seq2_in, seq2_target, seq2_out],
[seq3_in, seq3_target, seq3_out]]
Stagger and add up the groups of samples:
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
"""
num_folds, length = y.shape
target = length - 2 * overlap
total_len = num_folds * (target + overlap) + overlap
# Need some silence for the rnn warmup
silence_len = overlap // 2
fade_len = overlap - silence_len
silence = np.zeros((silence_len), dtype=np.float64)
# Equal power crossfade
t = np.linspace(-1, 1, fade_len, dtype=np.float64)
fade_in = np.sqrt(0.5 * (1 + t))
fade_out = np.sqrt(0.5 * (1 - t))
# Concat the silence to the fades
fade_in = np.concatenate([silence, fade_in])
fade_out = np.concatenate([fade_out, silence])
# Apply the gain to the overlap samples
y[:, :overlap] *= fade_in
y[:, -overlap:] *= fade_out
unfolded = np.zeros((total_len), dtype=np.float64)
# Loop to add up all the samples
for i in range(num_folds):
start = i * (target + overlap)
end = start + target + 2 * overlap
unfolded[start:end] += y[i]
return unfolded
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
mels = batch["input"]
waveform = batch["waveform"]
waveform_coarse = batch["waveform_coarse"]
y_hat = self.forward(waveform, mels)
if isinstance(self.args.mode, int):
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
else:
waveform_coarse = waveform_coarse.float()
waveform_coarse = waveform_coarse.unsqueeze(-1)
# compute losses
loss_dict = criterion(y_hat, waveform_coarse)
return {"model_output": y_hat}, loss_dict
def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
@torch.no_grad()
def test(
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, Dict]:
ap = self.ap
figures = {}
audios = {}
samples = test_loader.dataset.load_test_samples(1)
for idx, sample in enumerate(samples):
x = torch.FloatTensor(sample[0])
x = x.to(next(self.parameters()).device)
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
x_hat = ap.melspectrogram(y_hat)
figures.update(
{
f"test_{idx}/ground_truth": plot_spectrogram(x.T),
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
}
)
audios.update({f"test_{idx}/audio": y_hat})
# audios.update({f"real_{idx}/audio": y_hat})
return figures, audios
def test_log(
self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
figures, audios = outputs
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@staticmethod
def format_batch(batch: Dict) -> Dict:
waveform = batch[0]
mels = batch[1]
waveform_coarse = batch[2]
return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}
def get_data_loader( # pylint: disable=no-self-use
self,
config: Coqpit,
assets: Dict,
is_eval: True,
samples: List,
verbose: bool,
num_gpus: int,
):
ap = self.ap
dataset = WaveRNNDataset(
ap=ap,
items=samples,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad=config.model_args.pad,
mode=config.model_args.mode,
mulaw=config.model_args.mulaw,
is_training=not is_eval,
verbose=verbose,
)
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=1 if is_eval else config.batch_size,
shuffle=num_gpus == 0,
collate_fn=dataset.collate,
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=True,
)
return loader
def get_criterion(self):
# define train functions
return WaveRNNLoss(self.args.mode)
@staticmethod
def init_from_config(config: "WavernnConfig"):
return Wavernn(config)