Add files via upload
This commit is contained in:
@@ -0,0 +1,154 @@
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from coqpit import Coqpit
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_model(config: Coqpit):
|
||||
"""Load models directly from configuration."""
|
||||
if "discriminator_model" in config and "generator_model" in config:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.gan")
|
||||
MyModel = getattr(MyModel, "GAN")
|
||||
else:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
|
||||
if config.model.lower() == "wavernn":
|
||||
MyModel = getattr(MyModel, "Wavernn")
|
||||
elif config.model.lower() == "gan":
|
||||
MyModel = getattr(MyModel, "GAN")
|
||||
elif config.model.lower() == "wavegrad":
|
||||
MyModel = getattr(MyModel, "Wavegrad")
|
||||
else:
|
||||
try:
|
||||
MyModel = getattr(MyModel, to_camel(config.model))
|
||||
except ModuleNotFoundError as e:
|
||||
raise ValueError(f"Model {config.model} not exist!") from e
|
||||
print(" > Vocoder Model: {}".format(config.model))
|
||||
return MyModel.init_from_config(config)
|
||||
|
||||
|
||||
def setup_generator(c):
|
||||
"""TODO: use config object as arguments"""
|
||||
print(" > Generator Model: {}".format(c.generator_model))
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||
# this is to preserve the Wavernn class name (instead of Wavernn)
|
||||
if c.generator_model.lower() in "hifigan_generator":
|
||||
model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
|
||||
elif c.generator_model.lower() in "melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model in "melgan_fb_generator":
|
||||
raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
|
||||
elif c.generator_model.lower() in "multiband_melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=4,
|
||||
proj_kernel=7,
|
||||
base_channels=384,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model.lower() in "fullband_melgan_generator":
|
||||
model = MyModel(
|
||||
in_channels=c.audio["num_mels"],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
)
|
||||
elif c.generator_model.lower() in "parallel_wavegan_generator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_res_blocks=c.generator_model_params["num_res_blocks"],
|
||||
stacks=c.generator_model_params["stacks"],
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
aux_channels=c.audio["num_mels"],
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
use_weight_norm=True,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
)
|
||||
elif c.generator_model.lower() in "univnet_generator":
|
||||
model = MyModel(**c.generator_model_params)
|
||||
else:
|
||||
raise NotImplementedError(f"Model {c.generator_model} not implemented!")
|
||||
return model
|
||||
|
||||
|
||||
def setup_discriminator(c):
|
||||
"""TODO: use config objekt as arguments"""
|
||||
print(" > Discriminator Model: {}".format(c.discriminator_model))
|
||||
if "parallel_wavegan" in c.discriminator_model:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
|
||||
else:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
|
||||
if c.discriminator_model in "hifigan_discriminator":
|
||||
model = MyModel()
|
||||
if c.discriminator_model in "random_window_discriminator":
|
||||
model = MyModel(
|
||||
cond_channels=c.audio["num_mels"],
|
||||
hop_length=c.audio["hop_length"],
|
||||
uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
|
||||
cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
|
||||
cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
|
||||
window_sizes=c.discriminator_model_params["window_sizes"],
|
||||
)
|
||||
if c.discriminator_model in "melgan_multiscale_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=c.discriminator_model_params["base_channels"],
|
||||
max_channels=c.discriminator_model_params["max_channels"],
|
||||
downsample_factors=c.discriminator_model_params["downsample_factors"],
|
||||
)
|
||||
if c.discriminator_model == "residual_parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
stacks=c.discriminator_model_params["stacks"],
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
)
|
||||
if c.discriminator_model == "parallel_wavegan_discriminator":
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=c.discriminator_model_params["num_layers"],
|
||||
conv_channels=64,
|
||||
dilation_factor=1,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
bias=True,
|
||||
)
|
||||
if c.discriminator_model == "univnet_discriminator":
|
||||
model = MyModel()
|
||||
return model
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,55 @@
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.model import BaseTrainerModel
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseVocoder(BaseTrainerModel):
|
||||
"""Base `vocoder` class. Every new `vocoder` model must inherit this.
|
||||
|
||||
It defines `vocoder` specific functions on top of `Model`.
|
||||
|
||||
Notes on input/output tensor shapes:
|
||||
Any input or output tensor of the model must be shaped as
|
||||
|
||||
- 3D tensors `batch x time x channels`
|
||||
- 2D tensors `batch x channels`
|
||||
- 1D tensors `batch x 1`
|
||||
"""
|
||||
|
||||
MODEL_TYPE = "vocoder"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self._set_model_args(config)
|
||||
|
||||
def _set_model_args(self, config: Coqpit):
|
||||
"""Setup model args based on the config type.
|
||||
|
||||
If the config is for training with a name like "*Config", then the model args are embeded in the
|
||||
config.model_args
|
||||
|
||||
If the config is for the model with a name like "*Args", then we assign the directly.
|
||||
"""
|
||||
# don't use isintance not to import recursively
|
||||
if "Config" in config.__class__.__name__:
|
||||
if "characters" in config:
|
||||
_, self.config, num_chars = self.get_characters(config)
|
||||
self.config.num_chars = num_chars
|
||||
if hasattr(self.config, "model_args"):
|
||||
config.model_args.num_chars = num_chars
|
||||
if "model_args" in config:
|
||||
self.args = self.config.model_args
|
||||
# This is for backward compatibility
|
||||
if "model_params" in config:
|
||||
self.args = self.config.model_params
|
||||
else:
|
||||
self.config = config
|
||||
if "model_args" in config:
|
||||
self.args = self.config.model_args
|
||||
# This is for backward compatibility
|
||||
if "model_params" in config:
|
||||
self.args = self.config.model_params
|
||||
else:
|
||||
raise ValueError("config must be either a *Config or *Args")
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
||||
|
||||
class FullbandMelganGenerator(MelganGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=(2, 8, 2, 2),
|
||||
res_kernel=3,
|
||||
num_res_blocks=4,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
proj_kernel=proj_kernel,
|
||||
base_channels=base_channels,
|
||||
upsample_factors=upsample_factors,
|
||||
res_kernel=res_kernel,
|
||||
num_res_blocks=num_res_blocks,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, cond_features):
|
||||
cond_features = cond_features.to(self.layers[1].weight.device)
|
||||
cond_features = torch.nn.functional.pad(
|
||||
cond_features, (self.inference_padding, self.inference_padding), "replicate"
|
||||
)
|
||||
return self.layers(cond_features)
|
||||
@@ -0,0 +1,374 @@
|
||||
from inspect import signature
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||
from TTS.vocoder.models import setup_discriminator, setup_generator
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
class GAN(BaseVocoder):
|
||||
def __init__(self, config: Coqpit, ap: AudioProcessor = None):
|
||||
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
|
||||
It also helps mixing and matching different generator and disciminator networks easily.
|
||||
|
||||
To implement a new GAN models, you just need to define the generator and the discriminator networks, the rest
|
||||
is handled by the `GAN` class.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.
|
||||
|
||||
Examples:
|
||||
Initializing the GAN model with HifiGAN generator and discriminator.
|
||||
>>> from TTS.vocoder.configs import HifiganConfig
|
||||
>>> config = HifiganConfig()
|
||||
>>> model = GAN(config)
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.model_g = setup_generator(config)
|
||||
self.model_d = setup_discriminator(config)
|
||||
self.train_disc = False # if False, train only the generator.
|
||||
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
|
||||
self.ap = ap
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the generator's forward pass.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: output of the GAN generator network.
|
||||
"""
|
||||
return self.model_g.forward(x)
|
||||
|
||||
def inference(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the generator's inference pass.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
Returns:
|
||||
torch.Tensor: output of the GAN generator network.
|
||||
"""
|
||||
return self.model_g.inference(x)
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
|
||||
network on the current pass.
|
||||
|
||||
Args:
|
||||
batch (Dict): Batch of samples returned by the dataloader.
|
||||
criterion (Dict): Criterion used to compute the losses.
|
||||
optimizer_idx (int): ID of the optimizer in use on the current pass.
|
||||
|
||||
Raises:
|
||||
ValueError: `optimizer_idx` is an unexpected value.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: model outputs and the computed loss values.
|
||||
"""
|
||||
outputs = {}
|
||||
loss_dict = {}
|
||||
|
||||
x = batch["input"]
|
||||
y = batch["waveform"]
|
||||
|
||||
if optimizer_idx not in [0, 1]:
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# DISCRIMINATOR optimization
|
||||
|
||||
# generator pass
|
||||
y_hat = self.model_g(x)[:, :, : y.size(2)]
|
||||
|
||||
# cache for generator loss
|
||||
# pylint: disable=W0201
|
||||
self.y_hat_g = y_hat
|
||||
self.y_hat_sub = None
|
||||
self.y_sub_g = None
|
||||
|
||||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
self.y_hat_sub = y_hat
|
||||
y_hat = self.model_g.pqmf_synthesis(y_hat)
|
||||
self.y_hat_g = y_hat # save for generator loss
|
||||
self.y_sub_g = self.model_g.pqmf_analysis(y)
|
||||
|
||||
scores_fake, feats_fake, feats_real = None, None, None
|
||||
|
||||
if self.train_disc:
|
||||
# use different samples for G and D trainings
|
||||
if self.config.diff_samples_for_G_and_D:
|
||||
x_d = batch["input_disc"]
|
||||
y_d = batch["waveform_disc"]
|
||||
# use a different sample than generator
|
||||
with torch.no_grad():
|
||||
y_hat = self.model_g(x_d)
|
||||
|
||||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
y_hat = self.model_g.pqmf_synthesis(y_hat)
|
||||
else:
|
||||
# use the same samples as generator
|
||||
x_d = x.clone()
|
||||
y_d = y.clone()
|
||||
y_hat = self.y_hat_g
|
||||
|
||||
# run D with or without cond. features
|
||||
if len(signature(self.model_d.forward).parameters) == 2:
|
||||
D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
|
||||
D_out_real = self.model_d(y_d, x_d)
|
||||
else:
|
||||
D_out_fake = self.model_d(y_hat.detach())
|
||||
D_out_real = self.model_d(y_d)
|
||||
|
||||
# format D outputs
|
||||
if isinstance(D_out_fake, tuple):
|
||||
# self.model_d returns scores and features
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
scores_real, feats_real = None, None
|
||||
else:
|
||||
scores_real, feats_real = D_out_real
|
||||
else:
|
||||
# model D returns only scores
|
||||
scores_fake = D_out_fake
|
||||
scores_real = D_out_real
|
||||
|
||||
# compute losses
|
||||
loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
|
||||
outputs = {"model_outputs": y_hat}
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# GENERATOR loss
|
||||
scores_fake, feats_fake, feats_real = None, None, None
|
||||
if self.train_disc:
|
||||
if len(signature(self.model_d.forward).parameters) == 2:
|
||||
D_out_fake = self.model_d(self.y_hat_g, x)
|
||||
else:
|
||||
D_out_fake = self.model_d(self.y_hat_g)
|
||||
D_out_real = None
|
||||
|
||||
if self.config.use_feat_match_loss:
|
||||
with torch.no_grad():
|
||||
D_out_real = self.model_d(y)
|
||||
|
||||
# format D outputs
|
||||
if isinstance(D_out_fake, tuple):
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
feats_real = None
|
||||
else:
|
||||
_, feats_real = D_out_real
|
||||
else:
|
||||
scores_fake = D_out_fake
|
||||
feats_fake, feats_real = None, None
|
||||
|
||||
# compute losses
|
||||
loss_dict = criterion[optimizer_idx](
|
||||
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
|
||||
)
|
||||
outputs = {"model_outputs": self.y_hat_g}
|
||||
return outputs, loss_dict
|
||||
|
||||
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Logging shared by the training and evaluation.
|
||||
|
||||
Args:
|
||||
name (str): Name of the run. `train` or `eval`,
|
||||
ap (AudioProcessor): Audio processor used in training.
|
||||
batch (Dict): Batch used in the last train/eval step.
|
||||
outputs (Dict): Model outputs from the last train/eval step.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: log figures and audio samples.
|
||||
"""
|
||||
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
|
||||
y = batch["waveform"]
|
||||
figures = plot_results(y_hat, y, ap, name)
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
audios = {f"{name}/audio": sample_voice}
|
||||
return figures, audios
|
||||
|
||||
def train_log(
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
"""Call `_log()` for training."""
|
||||
figures, audios = self._log("eval", self.ap, batch, outputs)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Call `train_step()` with `no_grad()`"""
|
||||
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
|
||||
return self.train_step(batch, criterion, optimizer_idx)
|
||||
|
||||
def eval_log(
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
"""Call `_log()` for evaluation."""
|
||||
figures, audios = self._log("eval", self.ap, batch, outputs)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
config: Coqpit,
|
||||
checkpoint_path: str,
|
||||
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
|
||||
cache: bool = False,
|
||||
) -> None:
|
||||
"""Load a GAN checkpoint and initialize model parameters.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model config.
|
||||
checkpoint_path (str): Checkpoint file path.
|
||||
eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
|
||||
"""
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
# band-aid for older than v0.0.15 GAN models
|
||||
if "model_disc" in state:
|
||||
self.model_g.load_checkpoint(config, checkpoint_path, eval)
|
||||
else:
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.model_d = None
|
||||
if hasattr(self.model_g, "remove_weight_norm"):
|
||||
self.model_g.remove_weight_norm()
|
||||
|
||||
def on_train_step_start(self, trainer) -> None:
|
||||
"""Enable the discriminator training based on `steps_to_start_discriminator`
|
||||
|
||||
Args:
|
||||
trainer (Trainer): Trainer object.
|
||||
"""
|
||||
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
"""Initiate and return the GAN optimizers based on the config parameters.
|
||||
|
||||
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
|
||||
|
||||
Returns:
|
||||
List: optimizers.
|
||||
"""
|
||||
optimizer1 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
|
||||
)
|
||||
optimizer2 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
|
||||
)
|
||||
return [optimizer2, optimizer1]
|
||||
|
||||
def get_lr(self) -> List:
|
||||
"""Set the initial learning rates for each optimizer.
|
||||
|
||||
Returns:
|
||||
List: learning rates for each optimizer.
|
||||
"""
|
||||
return [self.config.lr_disc, self.config.lr_gen]
|
||||
|
||||
def get_scheduler(self, optimizer) -> List:
|
||||
"""Set the schedulers for each optimizer.
|
||||
|
||||
Args:
|
||||
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
|
||||
|
||||
Returns:
|
||||
List: Schedulers, one for each optimizer.
|
||||
"""
|
||||
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
|
||||
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||
return [scheduler2, scheduler1]
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: List) -> Dict:
|
||||
"""Format the batch for training.
|
||||
|
||||
Args:
|
||||
batch (List): Batch out of the dataloader.
|
||||
|
||||
Returns:
|
||||
Dict: formatted model inputs.
|
||||
"""
|
||||
if isinstance(batch[0], list):
|
||||
x_G, y_G = batch[0]
|
||||
x_D, y_D = batch[1]
|
||||
return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
|
||||
x, y = batch
|
||||
return {"input": x, "waveform": y}
|
||||
|
||||
def get_data_loader( # pylint: disable=no-self-use, unused-argument
|
||||
self,
|
||||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: True,
|
||||
samples: List,
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
rank: int = None, # pylint: disable=unused-argument
|
||||
):
|
||||
"""Initiate and return the GAN dataloader.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model config.
|
||||
ap (AudioProcessor): Audio processor.
|
||||
is_eval (True): Set the dataloader for evaluation if true.
|
||||
samples (List): Data samples.
|
||||
verbose (bool): Log information if true.
|
||||
num_gpus (int): Number of GPUs in use.
|
||||
rank (int): Rank of the current GPU. Defaults to None.
|
||||
|
||||
Returns:
|
||||
DataLoader: Torch dataloader.
|
||||
"""
|
||||
dataset = GANDataset(
|
||||
ap=self.ap,
|
||||
items=samples,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=self.ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
|
||||
is_training=not is_eval,
|
||||
return_segments=not is_eval,
|
||||
use_noise_augment=config.use_noise_augment,
|
||||
use_cache=config.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
dataset.shuffle_mapping()
|
||||
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1 if is_eval else config.batch_size,
|
||||
shuffle=num_gpus == 0,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def get_criterion(self):
|
||||
"""Return criterions for the optimizers"""
|
||||
return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: Coqpit, verbose=True) -> "GAN":
|
||||
ap = AudioProcessor.init_from_config(config, verbose=verbose)
|
||||
return GAN(config, ap=ap)
|
||||
@@ -0,0 +1,217 @@
|
||||
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
"""HiFiGAN Periodic Discriminator
|
||||
|
||||
Takes every Pth value from the input waveform and applied a stack of convoluations.
|
||||
|
||||
Note:
|
||||
if `period` is 2
|
||||
`waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat`
|
||||
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
[Tensor]: discriminator scores per sample in the batch.
|
||||
[List[Tensor]]: list of features from each convolutional layer.
|
||||
|
||||
Shapes:
|
||||
x: [B, 1, T]
|
||||
"""
|
||||
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
get_padding = lambda k, d: int((k * d - d) / 2)
|
||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
[Tensor]: discriminator scores per sample in the batch.
|
||||
[List[Tensor]]: list of features from each convolutional layer.
|
||||
|
||||
Shapes:
|
||||
x: [B, 1, T]
|
||||
"""
|
||||
feat = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
feat.append(x)
|
||||
x = self.conv_post(x)
|
||||
feat.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, feat
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
"""HiFiGAN Multi-Period Discriminator (MPD)
|
||||
Wrapper for the `PeriodDiscriminator` to apply it in different periods.
|
||||
Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
|
||||
"""
|
||||
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
|
||||
DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
[List[Tensor]]: list of scores from each discriminator.
|
||||
[List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
|
||||
|
||||
Shapes:
|
||||
x: [B, 1, T]
|
||||
"""
|
||||
scores = []
|
||||
feats = []
|
||||
for _, d in enumerate(self.discriminators):
|
||||
score, feat = d(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
return scores, feats
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
"""HiFiGAN Scale Discriminator.
|
||||
It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper.
|
||||
|
||||
Args:
|
||||
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
|
||||
norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
||||
norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
||||
norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
Tensor: discriminator scores.
|
||||
List[Tensor]: list of features from the convolutiona layers.
|
||||
"""
|
||||
feat = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
feat.append(x)
|
||||
x = self.conv_post(x)
|
||||
feat.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
return x, feat
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
|
||||
"""HiFiGAN Multi-Scale Discriminator.
|
||||
It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorS(use_spectral_norm=True),
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
]
|
||||
)
|
||||
self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)])
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: discriminator scores.
|
||||
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||
"""
|
||||
scores = []
|
||||
feats = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
x = self.meanpools[i - 1](x)
|
||||
score, feat = d(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
return scores, feats
|
||||
|
||||
|
||||
class HifiganDiscriminator(nn.Module):
|
||||
"""HiFiGAN discriminator wrapping MPD and MSD."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
self.msd = MultiScaleDiscriminator()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: discriminator scores.
|
||||
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||
"""
|
||||
scores, feats = self.mpd(x)
|
||||
scores_, feats_ = self.msd(x)
|
||||
return scores + scores_, feats + feats_
|
||||
@@ -0,0 +1,301 @@
|
||||
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def get_padding(k, d):
|
||||
return int((k * d - d) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|
||||
|--------------------------------------------------------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input tensor.
|
||||
Returns:
|
||||
Tensor: output tensor.
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
"""
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.convs2:
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|
||||
|---------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
|
||||
class HifiganGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
resblock_type,
|
||||
resblock_dilation_sizes,
|
||||
resblock_kernel_sizes,
|
||||
upsample_kernel_sizes,
|
||||
upsample_initial_channel,
|
||||
upsample_factors,
|
||||
inference_padding=5,
|
||||
cond_channels=0,
|
||||
conv_pre_weight_norm=True,
|
||||
conv_post_weight_norm=True,
|
||||
conv_post_bias=True,
|
||||
):
|
||||
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
|
||||
|
||||
Network:
|
||||
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
|
||||
.. -> zI ---|
|
||||
resblockN_kNx1 -> zN ---'
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
out_channels (int): number of output tensor channels.
|
||||
resblock_type (str): type of the `ResBlock`. '1' or '2'.
|
||||
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
|
||||
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
|
||||
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
|
||||
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
|
||||
for each consecutive upsampling layer.
|
||||
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
|
||||
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_padding = inference_padding
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_factors)
|
||||
# initial upsampling layers
|
||||
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
|
||||
# upsampling layers
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
# MRF blocks
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
# post convolution layer
|
||||
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
|
||||
if cond_channels > 0:
|
||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||
|
||||
if not conv_pre_weight_norm:
|
||||
remove_parametrizations(self.conv_pre, "weight")
|
||||
|
||||
if not conv_post_weight_norm:
|
||||
remove_parametrizations(self.conv_post, "weight")
|
||||
|
||||
def forward(self, x, g=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): feature input tensor.
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
o = self.conv_pre(x)
|
||||
if hasattr(self, "cond_layer"):
|
||||
o = o + self.cond_layer(g)
|
||||
for i in range(self.num_upsamples):
|
||||
o = F.leaky_relu(o, LRELU_SLOPE)
|
||||
o = self.ups[i](o)
|
||||
z_sum = None
|
||||
for j in range(self.num_kernels):
|
||||
if z_sum is None:
|
||||
z_sum = self.resblocks[i * self.num_kernels + j](o)
|
||||
else:
|
||||
z_sum += self.resblocks[i * self.num_kernels + j](o)
|
||||
o = z_sum / self.num_kernels
|
||||
o = F.leaky_relu(o)
|
||||
o = self.conv_post(o)
|
||||
o = torch.tanh(o)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_parametrizations(self.conv_pre, "weight")
|
||||
remove_parametrizations(self.conv_post, "weight")
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
@@ -0,0 +1,84 @@
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
|
||||
class MelganDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=16,
|
||||
max_channels=1024,
|
||||
downsample_factors=(4, 4, 4, 4),
|
||||
groups_denominator=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
layer_kernel_size = np.prod(kernel_sizes)
|
||||
layer_padding = (layer_kernel_size - 1) // 2
|
||||
|
||||
# initial layer
|
||||
self.layers += [
|
||||
nn.Sequential(
|
||||
nn.ReflectionPad1d(layer_padding),
|
||||
weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
]
|
||||
|
||||
# downsampling layers
|
||||
layer_in_channels = base_channels
|
||||
for downsample_factor in downsample_factors:
|
||||
layer_out_channels = min(layer_in_channels * downsample_factor, max_channels)
|
||||
layer_kernel_size = downsample_factor * 10 + 1
|
||||
layer_padding = (layer_kernel_size - 1) // 2
|
||||
layer_groups = layer_in_channels // groups_denominator
|
||||
self.layers += [
|
||||
nn.Sequential(
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
layer_in_channels,
|
||||
layer_out_channels,
|
||||
kernel_size=layer_kernel_size,
|
||||
stride=downsample_factor,
|
||||
padding=layer_padding,
|
||||
groups=layer_groups,
|
||||
)
|
||||
),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
]
|
||||
layer_in_channels = layer_out_channels
|
||||
|
||||
# last 2 layers
|
||||
layer_padding1 = (kernel_sizes[0] - 1) // 2
|
||||
layer_padding2 = (kernel_sizes[1] - 1) // 2
|
||||
self.layers += [
|
||||
nn.Sequential(
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
layer_out_channels,
|
||||
layer_out_channels,
|
||||
kernel_size=kernel_sizes[0],
|
||||
stride=1,
|
||||
padding=layer_padding1,
|
||||
)
|
||||
),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv1d(
|
||||
layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
feats = []
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
feats.append(x)
|
||||
return x, feats
|
||||
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.layers.melgan import ResidualStack
|
||||
|
||||
|
||||
class MelganGenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=(8, 8, 2, 2),
|
||||
res_kernel=3,
|
||||
num_res_blocks=3,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# assert model parameters
|
||||
assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."
|
||||
|
||||
# setup additional model parameters
|
||||
base_padding = (proj_kernel - 1) // 2
|
||||
act_slope = 0.2
|
||||
self.inference_padding = 2
|
||||
|
||||
# initial layer
|
||||
layers = []
|
||||
layers += [
|
||||
nn.ReflectionPad1d(base_padding),
|
||||
weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
|
||||
]
|
||||
|
||||
# upsampling layers and residual stacks
|
||||
for idx, upsample_factor in enumerate(upsample_factors):
|
||||
layer_in_channels = base_channels // (2**idx)
|
||||
layer_out_channels = base_channels // (2 ** (idx + 1))
|
||||
layer_filter_size = upsample_factor * 2
|
||||
layer_stride = upsample_factor
|
||||
layer_output_padding = upsample_factor % 2
|
||||
layer_padding = upsample_factor // 2 + layer_output_padding
|
||||
layers += [
|
||||
nn.LeakyReLU(act_slope),
|
||||
weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
layer_in_channels,
|
||||
layer_out_channels,
|
||||
layer_filter_size,
|
||||
stride=layer_stride,
|
||||
padding=layer_padding,
|
||||
output_padding=layer_output_padding,
|
||||
bias=True,
|
||||
)
|
||||
),
|
||||
ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
|
||||
]
|
||||
|
||||
layers += [nn.LeakyReLU(act_slope)]
|
||||
|
||||
# final layer
|
||||
layers += [
|
||||
nn.ReflectionPad1d(base_padding),
|
||||
weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
|
||||
nn.Tanh(),
|
||||
]
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, c):
|
||||
return self.layers(c)
|
||||
|
||||
def inference(self, c):
|
||||
c = c.to(self.layers[1].weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.layers(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for _, layer in enumerate(self.layers):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
nn.utils.parametrize.remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
@@ -0,0 +1,50 @@
|
||||
from torch import nn
|
||||
|
||||
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
|
||||
|
||||
|
||||
class MelganMultiscaleDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
num_scales=3,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=16,
|
||||
max_channels=1024,
|
||||
downsample_factors=(4, 4, 4),
|
||||
pooling_kernel_size=4,
|
||||
pooling_stride=2,
|
||||
pooling_padding=2,
|
||||
groups_denominator=4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
MelganDiscriminator(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_sizes=kernel_sizes,
|
||||
base_channels=base_channels,
|
||||
max_channels=max_channels,
|
||||
downsample_factors=downsample_factors,
|
||||
groups_denominator=groups_denominator,
|
||||
)
|
||||
for _ in range(num_scales)
|
||||
]
|
||||
)
|
||||
|
||||
self.pooling = nn.AvgPool1d(
|
||||
kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
scores = []
|
||||
feats = []
|
||||
for disc in self.discriminators:
|
||||
score, feat = disc(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
x = self.pooling(x)
|
||||
return scores, feats
|
||||
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.layers.pqmf import PQMF
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
||||
|
||||
class MultibandMelganGenerator(MelganGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=4,
|
||||
proj_kernel=7,
|
||||
base_channels=384,
|
||||
upsample_factors=(2, 8, 2, 2),
|
||||
res_kernel=3,
|
||||
num_res_blocks=3,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
proj_kernel=proj_kernel,
|
||||
base_channels=base_channels,
|
||||
upsample_factors=upsample_factors,
|
||||
res_kernel=res_kernel,
|
||||
num_res_blocks=num_res_blocks,
|
||||
)
|
||||
self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
|
||||
|
||||
def pqmf_analysis(self, x):
|
||||
return self.pqmf_layer.analysis(x)
|
||||
|
||||
def pqmf_synthesis(self, x):
|
||||
return self.pqmf_layer.synthesis(x)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, cond_features):
|
||||
cond_features = cond_features.to(self.layers[1].weight.device)
|
||||
cond_features = torch.nn.functional.pad(
|
||||
cond_features, (self.inference_padding, self.inference_padding), "replicate"
|
||||
)
|
||||
return self.pqmf_synthesis(self.layers(cond_features))
|
||||
@@ -0,0 +1,187 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||
|
||||
|
||||
class ParallelWaveganDiscriminator(nn.Module):
|
||||
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
|
||||
It classifies each audio window real/fake and returns a sequence
|
||||
of predictions.
|
||||
It is a stack of convolutional blocks with dilation.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=10,
|
||||
conv_channels=64,
|
||||
dilation_factor=1,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
|
||||
assert dilation_factor > 0, " [!] dilation factor must be > 0."
|
||||
self.conv_layers = nn.ModuleList()
|
||||
conv_in_channels = in_channels
|
||||
for i in range(num_layers - 1):
|
||||
if i == 0:
|
||||
dilation = 1
|
||||
else:
|
||||
dilation = i if dilation_factor == 1 else dilation_factor**i
|
||||
conv_in_channels = conv_channels
|
||||
padding = (kernel_size - 1) // 2 * dilation
|
||||
conv_layer = [
|
||||
nn.Conv1d(
|
||||
conv_in_channels,
|
||||
conv_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias,
|
||||
),
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
]
|
||||
self.conv_layers += conv_layer
|
||||
padding = (kernel_size - 1) // 2
|
||||
last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
|
||||
self.conv_layers += [last_conv_layer]
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x : (B, 1, T).
|
||||
Returns:
|
||||
Tensor: (B, 1, T)
|
||||
"""
|
||||
for f in self.conv_layers:
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
|
||||
class ResidualParallelWaveganDiscriminator(nn.Module):
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_layers=30,
|
||||
stacks=3,
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
nonlinear_activation="LeakyReLU",
|
||||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
):
|
||||
super().__init__()
|
||||
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_layers = num_layers
|
||||
self.stacks = stacks
|
||||
self.kernel_size = kernel_size
|
||||
self.res_factor = math.sqrt(1.0 / num_layers)
|
||||
|
||||
# check the number of num_layers and stacks
|
||||
assert num_layers % stacks == 0
|
||||
layers_per_stack = num_layers // stacks
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = nn.Sequential(
|
||||
nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True),
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
)
|
||||
|
||||
# define residual blocks
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for layer in range(num_layers):
|
||||
dilation = 2 ** (layer % layers_per_stack)
|
||||
conv = ResidualBlock(
|
||||
kernel_size=kernel_size,
|
||||
res_channels=res_channels,
|
||||
gate_channels=gate_channels,
|
||||
skip_channels=skip_channels,
|
||||
aux_channels=-1,
|
||||
dilation=dilation,
|
||||
dropout=dropout,
|
||||
bias=bias,
|
||||
use_causal_conv=False,
|
||||
)
|
||||
self.conv_layers += [conv]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = nn.ModuleList(
|
||||
[
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True),
|
||||
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
|
||||
nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: (B, 1, T).
|
||||
"""
|
||||
x = self.first_conv(x)
|
||||
|
||||
skips = 0
|
||||
for f in self.conv_layers:
|
||||
x, h = f(x, None)
|
||||
skips += h
|
||||
skips *= self.res_factor
|
||||
|
||||
# apply final layers
|
||||
x = skips
|
||||
for f in self.last_conv_layers:
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
print(f"Weight norm is removed from {m}.")
|
||||
remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
@@ -0,0 +1,164 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
|
||||
from TTS.vocoder.layers.upsample import ConvUpsample
|
||||
|
||||
|
||||
class ParallelWaveganGenerator(torch.nn.Module):
|
||||
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
|
||||
It is similar to WaveNet with no causal convolution.
|
||||
It is conditioned on an aux feature (spectrogram) to generate
|
||||
an output waveform from an input noise.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
num_res_blocks=30,
|
||||
stacks=3,
|
||||
res_channels=64,
|
||||
gate_channels=128,
|
||||
skip_channels=64,
|
||||
aux_channels=80,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
use_weight_norm=True,
|
||||
upsample_factors=[4, 4, 4, 4],
|
||||
inference_padding=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.aux_channels = aux_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.stacks = stacks
|
||||
self.kernel_size = kernel_size
|
||||
self.upsample_factors = upsample_factors
|
||||
self.upsample_scale = np.prod(upsample_factors)
|
||||
self.inference_padding = inference_padding
|
||||
self.use_weight_norm = use_weight_norm
|
||||
|
||||
# check the number of layers and stacks
|
||||
assert num_res_blocks % stacks == 0
|
||||
layers_per_stack = num_res_blocks // stacks
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True)
|
||||
|
||||
# define conv + upsampling network
|
||||
self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)
|
||||
|
||||
# define residual blocks
|
||||
self.conv_layers = torch.nn.ModuleList()
|
||||
for layer in range(num_res_blocks):
|
||||
dilation = 2 ** (layer % layers_per_stack)
|
||||
conv = ResidualBlock(
|
||||
kernel_size=kernel_size,
|
||||
res_channels=res_channels,
|
||||
gate_channels=gate_channels,
|
||||
skip_channels=skip_channels,
|
||||
aux_channels=aux_channels,
|
||||
dilation=dilation,
|
||||
dropout=dropout,
|
||||
bias=bias,
|
||||
)
|
||||
self.conv_layers += [conv]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
c: (B, C ,T').
|
||||
o: Output tensor (B, out_channels, T)
|
||||
"""
|
||||
# random noise
|
||||
x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
|
||||
# perform upsampling
|
||||
if c is not None and self.upsample_net is not None:
|
||||
c = self.upsample_net(c)
|
||||
assert (
|
||||
c.shape[-1] == x.shape[-1]
|
||||
), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
|
||||
|
||||
# encode to hidden representation
|
||||
x = self.first_conv(x)
|
||||
skips = 0
|
||||
for f in self.conv_layers:
|
||||
x, h = f(x, c)
|
||||
skips += h
|
||||
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
||||
|
||||
# apply final layers
|
||||
x = skips
|
||||
for f in self.last_conv_layers:
|
||||
x = f(x)
|
||||
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
c = c.to(self.first_conv.weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
# print(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
@@ -0,0 +1,203 @@
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GBlock(nn.Module):
|
||||
def __init__(self, in_channels, cond_channels, downsample_factor):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.downsample_factor = downsample_factor
|
||||
|
||||
self.start = nn.Sequential(
|
||||
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1),
|
||||
)
|
||||
self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1)
|
||||
self.end = nn.Sequential(
|
||||
nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2)
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.Conv1d(in_channels, in_channels * 2, kernel_size=1),
|
||||
nn.AvgPool1d(downsample_factor, stride=downsample_factor),
|
||||
)
|
||||
|
||||
def forward(self, inputs, conditions):
|
||||
outputs = self.start(inputs) + self.lc_conv1d(conditions)
|
||||
outputs = self.end(outputs)
|
||||
residual_outputs = self.residual(inputs)
|
||||
outputs = outputs + residual_outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class DBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, downsample_factor):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.downsample_factor = downsample_factor
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor)
|
||||
self.layers = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
|
||||
)
|
||||
self.residual = nn.Sequential(
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.downsample_factor > 1:
|
||||
outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs))
|
||||
else:
|
||||
outputs = self.layers(inputs) + self.residual(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class ConditionalDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)):
|
||||
super().__init__()
|
||||
|
||||
assert len(downsample_factors) == len(out_channels) + 1
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.downsample_factors = downsample_factors
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.pre_cond_layers = nn.ModuleList()
|
||||
self.post_cond_layers = nn.ModuleList()
|
||||
|
||||
# layers before condition features
|
||||
self.pre_cond_layers += [DBlock(in_channels, 64, 1)]
|
||||
in_channels = 64
|
||||
for i, channel in enumerate(out_channels):
|
||||
self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i]))
|
||||
in_channels = channel
|
||||
|
||||
# condition block
|
||||
self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1])
|
||||
|
||||
# layers after condition block
|
||||
self.post_cond_layers += [
|
||||
DBlock(in_channels * 2, in_channels * 2, 1),
|
||||
DBlock(in_channels * 2, in_channels * 2, 1),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
nn.Conv1d(in_channels * 2, 1, kernel_size=1),
|
||||
]
|
||||
|
||||
def forward(self, inputs, conditions):
|
||||
batch_size = inputs.size()[0]
|
||||
outputs = inputs.view(batch_size, self.in_channels, -1)
|
||||
for layer in self.pre_cond_layers:
|
||||
outputs = layer(outputs)
|
||||
outputs = self.cond_block(outputs, conditions)
|
||||
for layer in self.post_cond_layers:
|
||||
outputs = layer(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class UnconditionalDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)):
|
||||
super().__init__()
|
||||
|
||||
self.downsample_factors = downsample_factors
|
||||
self.in_channels = in_channels
|
||||
self.downsample_factors = downsample_factors
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.layers += [DBlock(self.in_channels, base_channels, 1)]
|
||||
in_channels = base_channels
|
||||
for i, factor in enumerate(downsample_factors):
|
||||
self.layers.append(DBlock(in_channels, out_channels[i], factor))
|
||||
in_channels *= 2
|
||||
self.layers += [
|
||||
DBlock(in_channels, in_channels, 1),
|
||||
DBlock(in_channels, in_channels, 1),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
nn.Conv1d(in_channels, 1, kernel_size=1),
|
||||
]
|
||||
|
||||
def forward(self, inputs):
|
||||
batch_size = inputs.size()[0]
|
||||
outputs = inputs.view(batch_size, self.in_channels, -1)
|
||||
for layer in self.layers:
|
||||
outputs = layer(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class RandomWindowDiscriminator(nn.Module):
|
||||
"""Random Window Discriminator as described in
|
||||
http://arxiv.org/abs/1909.11646"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cond_channels,
|
||||
hop_length,
|
||||
uncond_disc_donwsample_factors=(8, 4),
|
||||
cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)),
|
||||
cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)),
|
||||
window_sizes=(512, 1024, 2048, 4096, 8192),
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_channels = cond_channels
|
||||
self.window_sizes = window_sizes
|
||||
self.hop_length = hop_length
|
||||
self.base_window_size = self.hop_length * 2
|
||||
self.ks = [ws // self.base_window_size for ws in window_sizes]
|
||||
|
||||
# check arguments
|
||||
assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes)
|
||||
for ws in window_sizes:
|
||||
assert ws % hop_length == 0
|
||||
|
||||
for idx, cf in enumerate(cond_disc_downsample_factors):
|
||||
assert np.prod(cf) == hop_length // self.ks[idx]
|
||||
|
||||
# define layers
|
||||
self.unconditional_discriminators = nn.ModuleList([])
|
||||
for k in self.ks:
|
||||
layer = UnconditionalDiscriminator(
|
||||
in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors
|
||||
)
|
||||
self.unconditional_discriminators.append(layer)
|
||||
|
||||
self.conditional_discriminators = nn.ModuleList([])
|
||||
for idx, k in enumerate(self.ks):
|
||||
layer = ConditionalDiscriminator(
|
||||
in_channels=k,
|
||||
cond_channels=cond_channels,
|
||||
downsample_factors=cond_disc_downsample_factors[idx],
|
||||
out_channels=cond_disc_out_channels[idx],
|
||||
)
|
||||
self.conditional_discriminators.append(layer)
|
||||
|
||||
def forward(self, x, c):
|
||||
scores = []
|
||||
feats = []
|
||||
# unconditional pass
|
||||
for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators):
|
||||
index = np.random.randint(x.shape[-1] - window_size)
|
||||
|
||||
score = layer(x[:, :, index : index + window_size])
|
||||
scores.append(score)
|
||||
|
||||
# conditional pass
|
||||
for window_size, layer in zip(self.window_sizes, self.conditional_discriminators):
|
||||
frame_size = window_size // self.hop_length
|
||||
lc_index = np.random.randint(c.shape[-1] - frame_size)
|
||||
sample_index = lc_index * self.hop_length
|
||||
x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length]
|
||||
c_sub = c[:, :, lc_index : lc_index + frame_size]
|
||||
|
||||
score = layer(x_sub, c_sub)
|
||||
scores.append(score)
|
||||
return scores, feats
|
||||
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class SpecDiscriminator(nn.Module):
|
||||
"""docstring for Discriminator."""
|
||||
|
||||
def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.fft_size = fft_size
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.stft = TorchSTFT(fft_size, hop_length, win_length)
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
|
||||
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
||||
|
||||
def forward(self, y):
|
||||
fmap = []
|
||||
with torch.no_grad():
|
||||
y = y.squeeze(1)
|
||||
y = self.stft(y)
|
||||
y = y.unsqueeze(1)
|
||||
for _, d in enumerate(self.discriminators):
|
||||
y = d(y)
|
||||
y = F.leaky_relu(y, LRELU_SLOPE)
|
||||
fmap.append(y)
|
||||
|
||||
y = self.out(y)
|
||||
fmap.append(y)
|
||||
|
||||
return torch.flatten(y, 1, -1), fmap
|
||||
|
||||
|
||||
class MultiResSpecDiscriminator(torch.nn.Module):
|
||||
def __init__( # pylint: disable=dangerous-default-value
|
||||
self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window"
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
|
||||
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
|
||||
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
scores = []
|
||||
feats = []
|
||||
for d in self.discriminators:
|
||||
score, feat = d(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
|
||||
return scores, feats
|
||||
|
||||
|
||||
class UnivnetDiscriminator(nn.Module):
|
||||
"""Univnet discriminator wrapping MPD and MSD."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mpd = MultiPeriodDiscriminator()
|
||||
self.msd = MultiResSpecDiscriminator()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input waveform.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: discriminator scores.
|
||||
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||
"""
|
||||
scores, feats = self.mpd(x)
|
||||
scores_, feats_ = self.msd(x)
|
||||
return scores + scores_, feats + feats_
|
||||
@@ -0,0 +1,157 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
from TTS.vocoder.layers.lvc_block import LVCBlock
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class UnivnetGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int,
|
||||
cond_channels: int,
|
||||
upsample_factors: List[int],
|
||||
lvc_layers_each_block: int,
|
||||
lvc_kernel_size: int,
|
||||
kpnet_hidden_channels: int,
|
||||
kpnet_conv_size: int,
|
||||
dropout: float,
|
||||
use_weight_norm=True,
|
||||
):
|
||||
"""Univnet Generator network.
|
||||
|
||||
Paper: https://arxiv.org/pdf/2106.07889.pdf
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input tensor channels.
|
||||
out_channels (int): Number of channels of the output tensor.
|
||||
hidden_channels (int): Number of hidden network channels.
|
||||
cond_channels (int): Number of channels of the conditioning tensors.
|
||||
upsample_factors (List[int]): List of uplsample factors for the upsampling layers.
|
||||
lvc_layers_each_block (int): Number of LVC layers in each block.
|
||||
lvc_kernel_size (int): Kernel size of the LVC layers.
|
||||
kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
|
||||
kpnet_conv_size (int): Number of convolution channels in the key-point network.
|
||||
dropout (float): Dropout rate.
|
||||
use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.upsample_scale = np.prod(upsample_factors)
|
||||
self.lvc_block_nums = len(upsample_factors)
|
||||
|
||||
# define first convolution
|
||||
self.first_conv = torch.nn.Conv1d(
|
||||
in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
|
||||
)
|
||||
|
||||
# define residual blocks
|
||||
self.lvc_blocks = torch.nn.ModuleList()
|
||||
cond_hop_length = 1
|
||||
for n in range(self.lvc_block_nums):
|
||||
cond_hop_length = cond_hop_length * upsample_factors[n]
|
||||
lvcb = LVCBlock(
|
||||
in_channels=hidden_channels,
|
||||
cond_channels=cond_channels,
|
||||
upsample_ratio=upsample_factors[n],
|
||||
conv_layers=lvc_layers_each_block,
|
||||
conv_kernel_size=lvc_kernel_size,
|
||||
cond_hop_length=cond_hop_length,
|
||||
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
kpnet_dropout=dropout,
|
||||
)
|
||||
self.lvc_blocks += [lvcb]
|
||||
|
||||
# define output layers
|
||||
self.last_conv_layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Conv1d(
|
||||
hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, c):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
c (Tensor): Local conditioning auxiliary features (B, C ,T').
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T)
|
||||
"""
|
||||
# random noise
|
||||
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
x = self.first_conv(x)
|
||||
|
||||
for n in range(self.lvc_block_nums):
|
||||
x = self.lvc_blocks[n](x, c)
|
||||
|
||||
# apply final layers
|
||||
for f in self.last_conv_layers:
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = f(x)
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization module from all of the layers."""
|
||||
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
# print(f"Weight norm is removed from {m}.")
|
||||
parametrize.remove_parametrizations(m, "weight")
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
"""Apply weight normalization module from all of the layers."""
|
||||
|
||||
def _apply_weight_norm(m):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
|
||||
torch.nn.utils.parametrizations.weight_norm(m)
|
||||
# print(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
"""Return receptive field size."""
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""Perform inference.
|
||||
Args:
|
||||
c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
|
||||
Returns:
|
||||
Tensor: Output tensor (T, out_channels)
|
||||
"""
|
||||
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
|
||||
c = c.to(next(self.parameters()))
|
||||
return self.forward(c)
|
||||
@@ -0,0 +1,345 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets import WaveGradDataset
|
||||
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavegradArgs(Coqpit):
|
||||
in_channels: int = 80
|
||||
out_channels: int = 1
|
||||
use_weight_norm: bool = False
|
||||
y_conv_channels: int = 32
|
||||
x_conv_channels: int = 768
|
||||
dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
|
||||
ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
|
||||
upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
|
||||
upsample_dilations: List[List[int]] = field(
|
||||
default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
|
||||
)
|
||||
|
||||
|
||||
class Wavegrad(BaseVocoder):
|
||||
"""🐸 🌊 WaveGrad 🌊 model.
|
||||
Paper - https://arxiv.org/abs/2009.00713
|
||||
|
||||
Examples:
|
||||
Initializing the model.
|
||||
|
||||
>>> from TTS.vocoder.configs import WavegradConfig
|
||||
>>> config = WavegradConfig()
|
||||
>>> model = Wavegrad(config)
|
||||
|
||||
Paper Abstract:
|
||||
This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
|
||||
data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
|
||||
from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
|
||||
on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
|
||||
the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
|
||||
terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
|
||||
Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
|
||||
baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
|
||||
Audio samples are available at this https URL.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.use_weight_norm = config.model_params.use_weight_norm
|
||||
self.hop_len = np.prod(config.model_params.upsample_factors)
|
||||
self.noise_level = None
|
||||
self.num_steps = None
|
||||
self.beta = None
|
||||
self.alpha = None
|
||||
self.alpha_hat = None
|
||||
self.c1 = None
|
||||
self.c2 = None
|
||||
self.sigma = None
|
||||
|
||||
# dblocks
|
||||
self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
|
||||
self.dblocks = nn.ModuleList([])
|
||||
ic = config.model_params.y_conv_channels
|
||||
for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
|
||||
self.dblocks.append(DBlock(ic, oc, df))
|
||||
ic = oc
|
||||
|
||||
# film
|
||||
self.film = nn.ModuleList([])
|
||||
ic = config.model_params.y_conv_channels
|
||||
for oc in reversed(config.model_params.ublock_out_channels):
|
||||
self.film.append(FiLM(ic, oc))
|
||||
ic = oc
|
||||
|
||||
# ublocksn
|
||||
self.ublocks = nn.ModuleList([])
|
||||
ic = config.model_params.x_conv_channels
|
||||
for oc, uf, ud in zip(
|
||||
config.model_params.ublock_out_channels,
|
||||
config.model_params.upsample_factors,
|
||||
config.model_params.upsample_dilations,
|
||||
):
|
||||
self.ublocks.append(UBlock(ic, oc, uf, ud))
|
||||
ic = oc
|
||||
|
||||
self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
|
||||
self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)
|
||||
|
||||
if config.model_params.use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(self, x, spectrogram, noise_scale):
|
||||
shift_and_scale = []
|
||||
|
||||
x = self.y_conv(x)
|
||||
shift_and_scale.append(self.film[0](x, noise_scale))
|
||||
|
||||
for film, layer in zip(self.film[1:], self.dblocks):
|
||||
x = layer(x)
|
||||
shift_and_scale.append(film(x, noise_scale))
|
||||
|
||||
x = self.x_conv(spectrogram)
|
||||
for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)):
|
||||
x = layer(x, film_shift, film_scale)
|
||||
x = self.out_conv(x)
|
||||
return x
|
||||
|
||||
def load_noise_schedule(self, path):
|
||||
beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg
|
||||
self.compute_noise_level(beta)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, y_n=None):
|
||||
"""
|
||||
Shapes:
|
||||
x: :math:`[B, C , T]`
|
||||
y_n: :math:`[B, 1, T]`
|
||||
"""
|
||||
if y_n is None:
|
||||
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
|
||||
else:
|
||||
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
|
||||
y_n = y_n.type_as(x)
|
||||
sqrt_alpha_hat = self.noise_level.to(x)
|
||||
for n in range(len(self.alpha) - 1, -1, -1):
|
||||
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
|
||||
if n > 0:
|
||||
z = torch.randn_like(y_n)
|
||||
y_n += self.sigma[n - 1] * z
|
||||
y_n.clamp_(-1.0, 1.0)
|
||||
return y_n
|
||||
|
||||
def compute_y_n(self, y_0):
|
||||
"""Compute noisy audio based on noise schedule"""
|
||||
self.noise_level = self.noise_level.to(y_0)
|
||||
if len(y_0.shape) == 3:
|
||||
y_0 = y_0.squeeze(1)
|
||||
s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
|
||||
l_a, l_b = self.noise_level[s], self.noise_level[s + 1]
|
||||
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
|
||||
noise_scale = noise_scale.unsqueeze(1)
|
||||
noise = torch.randn_like(y_0)
|
||||
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
|
||||
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
|
||||
|
||||
def compute_noise_level(self, beta):
|
||||
"""Compute noise schedule parameters"""
|
||||
self.num_steps = len(beta)
|
||||
alpha = 1 - beta
|
||||
alpha_hat = np.cumprod(alpha)
|
||||
noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0)
|
||||
noise_level = alpha_hat**0.5
|
||||
|
||||
# pylint: disable=not-callable
|
||||
self.beta = torch.tensor(beta.astype(np.float32))
|
||||
self.alpha = torch.tensor(alpha.astype(np.float32))
|
||||
self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
|
||||
self.noise_level = torch.tensor(noise_level.astype(np.float32))
|
||||
|
||||
self.c1 = 1 / self.alpha**0.5
|
||||
self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
|
||||
self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for _, layer in enumerate(self.dblocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.film):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.ublocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
remove_parametrizations(layer, "weight")
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
remove_parametrizations(self.x_conv, "weight")
|
||||
remove_parametrizations(self.out_conv, "weight")
|
||||
remove_parametrizations(self.y_conv, "weight")
|
||||
|
||||
def apply_weight_norm(self):
|
||||
for _, layer in enumerate(self.dblocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.film):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
for _, layer in enumerate(self.ublocks):
|
||||
if len(layer.state_dict()) != 0:
|
||||
layer.apply_weight_norm()
|
||||
|
||||
self.x_conv = weight_norm(self.x_conv)
|
||||
self.out_conv = weight_norm(self.out_conv)
|
||||
self.y_conv = weight_norm(self.y_conv)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.config.model_params.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
betas = np.linspace(
|
||||
config["test_noise_schedule"]["min_val"],
|
||||
config["test_noise_schedule"]["max_val"],
|
||||
config["test_noise_schedule"]["num_steps"],
|
||||
)
|
||||
self.compute_noise_level(betas)
|
||||
else:
|
||||
betas = np.linspace(
|
||||
config["train_noise_schedule"]["min_val"],
|
||||
config["train_noise_schedule"]["max_val"],
|
||||
config["train_noise_schedule"]["num_steps"],
|
||||
)
|
||||
self.compute_noise_level(betas)
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
# format data
|
||||
x = batch["input"]
|
||||
y = batch["waveform"]
|
||||
|
||||
# set noise scale
|
||||
noise, x_noisy, noise_scale = self.compute_y_n(y)
|
||||
|
||||
# forward pass
|
||||
noise_hat = self.forward(x_noisy, x, noise_scale)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(noise, noise_hat)
|
||||
return {"model_output": noise_hat}, {"loss": loss}
|
||||
|
||||
def train_log( # pylint: disable=no-self-use
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log( # pylint: disable=no-self-use
|
||||
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument
|
||||
# setup noise schedule and inference
|
||||
ap = assets["audio_processor"]
|
||||
noise_schedule = self.config["test_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for sample in samples:
|
||||
x = sample[0]
|
||||
x = x[None, :, :].to(next(self.parameters()).device)
|
||||
y = sample[1]
|
||||
y = y[None, :]
|
||||
# compute voice
|
||||
y_pred = self.inference(x)
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_pred, y, ap, "test")
|
||||
# Sample audio
|
||||
sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
|
||||
return figures, {"test/audio": sample_voice}
|
||||
|
||||
def get_optimizer(self):
|
||||
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
|
||||
|
||||
def get_scheduler(self, optimizer):
|
||||
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
|
||||
|
||||
@staticmethod
|
||||
def get_criterion():
|
||||
return torch.nn.L1Loss()
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: Dict) -> Dict:
|
||||
# return a whole audio segment
|
||||
m, y = batch[0], batch[1]
|
||||
y = y.unsqueeze(1)
|
||||
return {"input": m, "waveform": y}
|
||||
|
||||
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
|
||||
ap = assets["audio_processor"]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=samples,
|
||||
seq_len=self.config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=self.config.pad_short,
|
||||
conv_pad=self.config.conv_pad,
|
||||
is_training=not is_eval,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=config.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
shuffle=num_gpus <= 1,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def on_epoch_start(self, trainer): # pylint: disable=unused-argument
|
||||
noise_schedule = self.config["train_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavegradConfig"):
|
||||
return Wavegrad(config)
|
||||
@@ -0,0 +1,646 @@
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_decode
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
from TTS.vocoder.layers.losses import WaveRNNLoss
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
|
||||
|
||||
|
||||
def stream(string, variables):
|
||||
sys.stdout.write(f"\r{string}" % variables)
|
||||
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
# relates https://github.com/pytorch/pytorch/issues/42305
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
||||
self.batch_norm1 = nn.BatchNorm1d(dims)
|
||||
self.batch_norm2 = nn.BatchNorm1d(dims)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.batch_norm1(x)
|
||||
x = F.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.batch_norm2(x)
|
||||
return x + residual
|
||||
|
||||
|
||||
class MelResNet(nn.Module):
|
||||
def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
|
||||
super().__init__()
|
||||
k_size = pad * 2 + 1
|
||||
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
|
||||
self.batch_norm = nn.BatchNorm1d(compute_dims)
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_res_blocks):
|
||||
self.layers.append(ResBlock(compute_dims))
|
||||
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.batch_norm(x)
|
||||
x = F.relu(x)
|
||||
for f in self.layers:
|
||||
x = f(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Stretch2d(nn.Module):
|
||||
def __init__(self, x_scale, y_scale):
|
||||
super().__init__()
|
||||
self.x_scale = x_scale
|
||||
self.y_scale = y_scale
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
x = x.unsqueeze(-1).unsqueeze(3)
|
||||
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
|
||||
return x.view(b, c, h * self.y_scale, w * self.x_scale)
|
||||
|
||||
|
||||
class UpsampleNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feat_dims,
|
||||
upsample_scales,
|
||||
compute_dims,
|
||||
num_res_blocks,
|
||||
res_out_dims,
|
||||
pad,
|
||||
use_aux_net,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_scale = np.cumproduct(upsample_scales)[-1]
|
||||
self.indent = pad * self.total_scale
|
||||
self.use_aux_net = use_aux_net
|
||||
if use_aux_net:
|
||||
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
||||
self.resnet_stretch = Stretch2d(self.total_scale, 1)
|
||||
self.up_layers = nn.ModuleList()
|
||||
for scale in upsample_scales:
|
||||
k_size = (1, scale * 2 + 1)
|
||||
padding = (0, scale)
|
||||
stretch = Stretch2d(scale, 1)
|
||||
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
|
||||
conv.weight.data.fill_(1.0 / k_size[1])
|
||||
self.up_layers.append(stretch)
|
||||
self.up_layers.append(conv)
|
||||
|
||||
def forward(self, m):
|
||||
if self.use_aux_net:
|
||||
aux = self.resnet(m).unsqueeze(1)
|
||||
aux = self.resnet_stretch(aux)
|
||||
aux = aux.squeeze(1)
|
||||
aux = aux.transpose(1, 2)
|
||||
else:
|
||||
aux = None
|
||||
m = m.unsqueeze(1)
|
||||
for f in self.up_layers:
|
||||
m = f(m)
|
||||
m = m.squeeze(1)[:, :, self.indent : -self.indent]
|
||||
return m.transpose(1, 2), aux
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.pad = pad
|
||||
self.indent = pad * scale
|
||||
self.use_aux_net = use_aux_net
|
||||
self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
||||
|
||||
def forward(self, m):
|
||||
if self.use_aux_net:
|
||||
aux = self.resnet(m)
|
||||
aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True)
|
||||
aux = aux.transpose(1, 2)
|
||||
else:
|
||||
aux = None
|
||||
m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True)
|
||||
m = m[:, :, self.indent : -self.indent]
|
||||
m = m * 0.045 # empirically found
|
||||
|
||||
return m.transpose(1, 2), aux
|
||||
|
||||
|
||||
@dataclass
|
||||
class WavernnArgs(Coqpit):
|
||||
"""🐸 WaveRNN model arguments.
|
||||
|
||||
rnn_dims (int):
|
||||
Number of hidden channels in RNN layers. Defaults to 512.
|
||||
fc_dims (int):
|
||||
Number of hidden channels in fully-conntected layers. Defaults to 512.
|
||||
compute_dims (int):
|
||||
Number of hidden channels in the feature ResNet. Defaults to 128.
|
||||
res_out_dim (int):
|
||||
Number of hidden channels in the feature ResNet output. Defaults to 128.
|
||||
num_res_blocks (int):
|
||||
Number of residual blocks in the ResNet. Defaults to 10.
|
||||
use_aux_net (bool):
|
||||
enable/disable the feature ResNet. Defaults to True.
|
||||
use_upsample_net (bool):
|
||||
enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True.
|
||||
upsample_factors (list):
|
||||
Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```.
|
||||
mode (str):
|
||||
Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
|
||||
Gaussian Distribution and `bits` for quantized bits as the model's output.
|
||||
mulaw (bool):
|
||||
enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
|
||||
to `True`.
|
||||
pad (int):
|
||||
Padding applied to the input feature frames against the convolution layers of the feature network.
|
||||
Defaults to 2.
|
||||
"""
|
||||
|
||||
rnn_dims: int = 512
|
||||
fc_dims: int = 512
|
||||
compute_dims: int = 128
|
||||
res_out_dims: int = 128
|
||||
num_res_blocks: int = 10
|
||||
use_aux_net: bool = True
|
||||
use_upsample_net: bool = True
|
||||
upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
|
||||
mode: str = "mold" # mold [string], gauss [string], bits [int]
|
||||
mulaw: bool = True # apply mulaw if mode is bits
|
||||
pad: int = 2
|
||||
feat_dims: int = 80
|
||||
|
||||
|
||||
class Wavernn(BaseVocoder):
|
||||
def __init__(self, config: Coqpit):
|
||||
"""🐸 WaveRNN model.
|
||||
Original paper - https://arxiv.org/abs/1802.08435
|
||||
Official implementation - https://github.com/fatchord/WaveRNN
|
||||
|
||||
Args:
|
||||
config (Coqpit): [description]
|
||||
|
||||
Raises:
|
||||
RuntimeError: [description]
|
||||
|
||||
Examples:
|
||||
>>> from TTS.vocoder.configs import WavernnConfig
|
||||
>>> config = WavernnConfig()
|
||||
>>> model = Wavernn(config)
|
||||
|
||||
Paper Abstract:
|
||||
Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
|
||||
both estimating the data distribution and generating high-quality samples. Efficient sampling for this
|
||||
class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
|
||||
describe a set of general techniques for reducing sampling time while maintaining high output quality.
|
||||
We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
|
||||
matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
|
||||
possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
|
||||
pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
|
||||
parameters, large sparse networks perform better than small dense networks and this relationship holds for
|
||||
sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
|
||||
high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
|
||||
subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
|
||||
samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
|
||||
orthogonal method for increasing sampling efficiency.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
if isinstance(self.args.mode, int):
|
||||
self.n_classes = 2**self.args.mode
|
||||
elif self.args.mode == "mold":
|
||||
self.n_classes = 3 * 10
|
||||
elif self.args.mode == "gauss":
|
||||
self.n_classes = 2
|
||||
else:
|
||||
raise RuntimeError("Unknown model mode value - ", self.args.mode)
|
||||
|
||||
self.ap = AudioProcessor(**config.audio.to_dict())
|
||||
self.aux_dims = self.args.res_out_dims // 4
|
||||
|
||||
if self.args.use_upsample_net:
|
||||
assert (
|
||||
np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length
|
||||
), " [!] upsample scales needs to be equal to hop_length"
|
||||
self.upsample = UpsampleNetwork(
|
||||
self.args.feat_dims,
|
||||
self.args.upsample_factors,
|
||||
self.args.compute_dims,
|
||||
self.args.num_res_blocks,
|
||||
self.args.res_out_dims,
|
||||
self.args.pad,
|
||||
self.args.use_aux_net,
|
||||
)
|
||||
else:
|
||||
self.upsample = Upsample(
|
||||
config.audio.hop_length,
|
||||
self.args.pad,
|
||||
self.args.num_res_blocks,
|
||||
self.args.feat_dims,
|
||||
self.args.compute_dims,
|
||||
self.args.res_out_dims,
|
||||
self.args.use_aux_net,
|
||||
)
|
||||
if self.args.use_aux_net:
|
||||
self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
|
||||
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
|
||||
self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
|
||||
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
|
||||
else:
|
||||
self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
|
||||
self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
|
||||
self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
|
||||
self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
|
||||
self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
|
||||
|
||||
def forward(self, x, mels):
|
||||
bsize = x.size(0)
|
||||
h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
|
||||
h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
|
||||
mels, aux = self.upsample(mels)
|
||||
|
||||
if self.args.use_aux_net:
|
||||
aux_idx = [self.aux_dims * i for i in range(5)]
|
||||
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
|
||||
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
|
||||
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
|
||||
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
|
||||
|
||||
x = (
|
||||
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
|
||||
if self.args.use_aux_net
|
||||
else torch.cat([x.unsqueeze(-1), mels], dim=2)
|
||||
)
|
||||
x = self.I(x)
|
||||
res = x
|
||||
self.rnn1.flatten_parameters()
|
||||
x, _ = self.rnn1(x, h1)
|
||||
|
||||
x = x + res
|
||||
res = x
|
||||
x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
|
||||
self.rnn2.flatten_parameters()
|
||||
x, _ = self.rnn2(x, h2)
|
||||
|
||||
x = x + res
|
||||
x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc1(x))
|
||||
|
||||
x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc2(x))
|
||||
return self.fc3(x)
|
||||
|
||||
def inference(self, mels, batched=None, target=None, overlap=None):
|
||||
self.eval()
|
||||
output = []
|
||||
start = time.time()
|
||||
rnn1 = self.get_gru_cell(self.rnn1)
|
||||
rnn2 = self.get_gru_cell(self.rnn2)
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(mels, np.ndarray):
|
||||
mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
|
||||
|
||||
if mels.ndim == 2:
|
||||
mels = mels.unsqueeze(0)
|
||||
wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length
|
||||
|
||||
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
|
||||
mels, aux = self.upsample(mels.transpose(1, 2))
|
||||
|
||||
if batched:
|
||||
mels = self.fold_with_overlap(mels, target, overlap)
|
||||
if aux is not None:
|
||||
aux = self.fold_with_overlap(aux, target, overlap)
|
||||
|
||||
b_size, seq_len, _ = mels.size()
|
||||
|
||||
h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
|
||||
h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
|
||||
x = torch.zeros(b_size, 1).type_as(mels)
|
||||
|
||||
if self.args.use_aux_net:
|
||||
d = self.aux_dims
|
||||
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
|
||||
|
||||
for i in range(seq_len):
|
||||
m_t = mels[:, i, :]
|
||||
|
||||
if self.args.use_aux_net:
|
||||
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
|
||||
|
||||
x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
|
||||
x = self.I(x)
|
||||
h1 = rnn1(x, h1)
|
||||
|
||||
x = x + h1
|
||||
inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
|
||||
h2 = rnn2(inp, h2)
|
||||
|
||||
x = x + h2
|
||||
x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc1(x))
|
||||
|
||||
x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
|
||||
x = F.relu(self.fc2(x))
|
||||
|
||||
logits = self.fc3(x)
|
||||
|
||||
if self.args.mode == "mold":
|
||||
sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).type_as(mels)
|
||||
elif self.args.mode == "gauss":
|
||||
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).type_as(mels)
|
||||
elif isinstance(self.args.mode, int):
|
||||
posterior = F.softmax(logits, dim=1)
|
||||
distrib = torch.distributions.Categorical(posterior)
|
||||
|
||||
sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
|
||||
output.append(sample)
|
||||
x = sample.unsqueeze(-1)
|
||||
else:
|
||||
raise RuntimeError("Unknown model mode value - ", self.args.mode)
|
||||
|
||||
if i % 100 == 0:
|
||||
self.gen_display(i, seq_len, b_size, start)
|
||||
|
||||
output = torch.stack(output).transpose(0, 1)
|
||||
output = output.cpu()
|
||||
if batched:
|
||||
output = output.numpy()
|
||||
output = output.astype(np.float64)
|
||||
|
||||
output = self.xfade_and_unfold(output, target, overlap)
|
||||
else:
|
||||
output = output[0]
|
||||
|
||||
if self.args.mulaw and isinstance(self.args.mode, int):
|
||||
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
|
||||
|
||||
# Fade-out at the end to avoid signal cutting out suddenly
|
||||
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
|
||||
output = output[:wave_len]
|
||||
|
||||
if wave_len > len(fade_out):
|
||||
output[-20 * self.config.audio.hop_length :] *= fade_out
|
||||
|
||||
self.train()
|
||||
return output
|
||||
|
||||
def gen_display(self, i, seq_len, b_size, start):
|
||||
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
|
||||
realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
|
||||
stream(
|
||||
"%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ",
|
||||
(i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
|
||||
)
|
||||
|
||||
def fold_with_overlap(self, x, target, overlap):
|
||||
"""Fold the tensor with overlap for quick batched inference.
|
||||
Overlap will be used for crossfading in xfade_and_unfold()
|
||||
Args:
|
||||
x (tensor) : Upsampled conditioning features.
|
||||
shape=(1, timesteps, features)
|
||||
target (int) : Target timesteps for each index of batch
|
||||
overlap (int) : Timesteps for both xfade and rnn warmup
|
||||
Return:
|
||||
(tensor) : shape=(num_folds, target + 2 * overlap, features)
|
||||
Details:
|
||||
x = [[h1, h2, ... hn]]
|
||||
Where each h is a vector of conditioning features
|
||||
Eg: target=2, overlap=1 with x.size(1)=10
|
||||
folded = [[h1, h2, h3, h4],
|
||||
[h4, h5, h6, h7],
|
||||
[h7, h8, h9, h10]]
|
||||
"""
|
||||
|
||||
_, total_len, features = x.size()
|
||||
|
||||
# Calculate variables needed
|
||||
num_folds = (total_len - overlap) // (target + overlap)
|
||||
extended_len = num_folds * (overlap + target) + overlap
|
||||
remaining = total_len - extended_len
|
||||
|
||||
# Pad if some time steps poking out
|
||||
if remaining != 0:
|
||||
num_folds += 1
|
||||
padding = target + 2 * overlap - remaining
|
||||
x = self.pad_tensor(x, padding, side="after")
|
||||
|
||||
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
|
||||
|
||||
# Get the values for the folded tensor
|
||||
for i in range(num_folds):
|
||||
start = i * (target + overlap)
|
||||
end = start + target + 2 * overlap
|
||||
folded[i] = x[:, start:end, :]
|
||||
|
||||
return folded
|
||||
|
||||
@staticmethod
|
||||
def get_gru_cell(gru):
|
||||
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
|
||||
gru_cell.weight_hh.data = gru.weight_hh_l0.data
|
||||
gru_cell.weight_ih.data = gru.weight_ih_l0.data
|
||||
gru_cell.bias_hh.data = gru.bias_hh_l0.data
|
||||
gru_cell.bias_ih.data = gru.bias_ih_l0.data
|
||||
return gru_cell
|
||||
|
||||
@staticmethod
|
||||
def pad_tensor(x, pad, side="both"):
|
||||
# NB - this is just a quick method i need right now
|
||||
# i.e., it won't generalise to other shapes/dims
|
||||
b, t, c = x.size()
|
||||
total = t + 2 * pad if side == "both" else t + pad
|
||||
padded = torch.zeros(b, total, c).to(x.device)
|
||||
if side in ("before", "both"):
|
||||
padded[:, pad : pad + t, :] = x
|
||||
elif side == "after":
|
||||
padded[:, :t, :] = x
|
||||
return padded
|
||||
|
||||
@staticmethod
|
||||
def xfade_and_unfold(y, target, overlap):
|
||||
"""Applies a crossfade and unfolds into a 1d array.
|
||||
Args:
|
||||
y (ndarry) : Batched sequences of audio samples
|
||||
shape=(num_folds, target + 2 * overlap)
|
||||
dtype=np.float64
|
||||
overlap (int) : Timesteps for both xfade and rnn warmup
|
||||
Return:
|
||||
(ndarry) : audio samples in a 1d array
|
||||
shape=(total_len)
|
||||
dtype=np.float64
|
||||
Details:
|
||||
y = [[seq1],
|
||||
[seq2],
|
||||
[seq3]]
|
||||
Apply a gain envelope at both ends of the sequences
|
||||
y = [[seq1_in, seq1_target, seq1_out],
|
||||
[seq2_in, seq2_target, seq2_out],
|
||||
[seq3_in, seq3_target, seq3_out]]
|
||||
Stagger and add up the groups of samples:
|
||||
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
|
||||
"""
|
||||
|
||||
num_folds, length = y.shape
|
||||
target = length - 2 * overlap
|
||||
total_len = num_folds * (target + overlap) + overlap
|
||||
|
||||
# Need some silence for the rnn warmup
|
||||
silence_len = overlap // 2
|
||||
fade_len = overlap - silence_len
|
||||
silence = np.zeros((silence_len), dtype=np.float64)
|
||||
|
||||
# Equal power crossfade
|
||||
t = np.linspace(-1, 1, fade_len, dtype=np.float64)
|
||||
fade_in = np.sqrt(0.5 * (1 + t))
|
||||
fade_out = np.sqrt(0.5 * (1 - t))
|
||||
|
||||
# Concat the silence to the fades
|
||||
fade_in = np.concatenate([silence, fade_in])
|
||||
fade_out = np.concatenate([fade_out, silence])
|
||||
|
||||
# Apply the gain to the overlap samples
|
||||
y[:, :overlap] *= fade_in
|
||||
y[:, -overlap:] *= fade_out
|
||||
|
||||
unfolded = np.zeros((total_len), dtype=np.float64)
|
||||
|
||||
# Loop to add up all the samples
|
||||
for i in range(num_folds):
|
||||
start = i * (target + overlap)
|
||||
end = start + target + 2 * overlap
|
||||
unfolded[start:end] += y[i]
|
||||
|
||||
return unfolded
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
mels = batch["input"]
|
||||
waveform = batch["waveform"]
|
||||
waveform_coarse = batch["waveform_coarse"]
|
||||
|
||||
y_hat = self.forward(waveform, mels)
|
||||
if isinstance(self.args.mode, int):
|
||||
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
waveform_coarse = waveform_coarse.float()
|
||||
waveform_coarse = waveform_coarse.unsqueeze(-1)
|
||||
# compute losses
|
||||
loss_dict = criterion(y_hat, waveform_coarse)
|
||||
return {"model_output": y_hat}, loss_dict
|
||||
|
||||
def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
@torch.no_grad()
|
||||
def test(
|
||||
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, Dict]:
|
||||
ap = self.ap
|
||||
figures = {}
|
||||
audios = {}
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for idx, sample in enumerate(samples):
|
||||
x = torch.FloatTensor(sample[0])
|
||||
x = x.to(next(self.parameters()).device)
|
||||
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
|
||||
x_hat = ap.melspectrogram(y_hat)
|
||||
figures.update(
|
||||
{
|
||||
f"test_{idx}/ground_truth": plot_spectrogram(x.T),
|
||||
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
|
||||
}
|
||||
)
|
||||
audios.update({f"test_{idx}/audio": y_hat})
|
||||
# audios.update({f"real_{idx}/audio": y_hat})
|
||||
return figures, audios
|
||||
|
||||
def test_log(
|
||||
self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
figures, audios = outputs
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: Dict) -> Dict:
|
||||
waveform = batch[0]
|
||||
mels = batch[1]
|
||||
waveform_coarse = batch[2]
|
||||
return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}
|
||||
|
||||
def get_data_loader( # pylint: disable=no-self-use
|
||||
self,
|
||||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: True,
|
||||
samples: List,
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
):
|
||||
ap = self.ap
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=samples,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=config.model_args.pad,
|
||||
mode=config.model_args.mode,
|
||||
mulaw=config.model_args.mulaw,
|
||||
is_training=not is_eval,
|
||||
verbose=verbose,
|
||||
)
|
||||
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1 if is_eval else config.batch_size,
|
||||
shuffle=num_gpus == 0,
|
||||
collate_fn=dataset.collate,
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
return loader
|
||||
|
||||
def get_criterion(self):
|
||||
# define train functions
|
||||
return WaveRNNLoss(self.args.mode)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "WavernnConfig"):
|
||||
return Wavernn(config)
|
||||
Reference in New Issue
Block a user