Add files via upload

This commit is contained in:
Sam Khoze
2024-06-18 13:21:08 -07:00
committed by GitHub
parent 6af40bc6cf
commit dc8b8bca5a
97 changed files with 10910 additions and 0 deletions
View File
+56
View File
@@ -0,0 +1,56 @@
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
# pylint: disable=dangerous-default-value
class ResStack(nn.Module):
def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]):
super().__init__()
resstack = []
for dilation in dilations:
resstack += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding),
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
]
self.resstack = nn.Sequential(*resstack)
self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
def forward(self, x):
x1 = self.shortcut(x)
x2 = self.resstack(x)
return x1 + x2
def remove_weight_norm(self):
remove_parametrizations(self.shortcut, "weight")
remove_parametrizations(self.resstack[2], "weight")
remove_parametrizations(self.resstack[5], "weight")
remove_parametrizations(self.resstack[8], "weight")
remove_parametrizations(self.resstack[11], "weight")
remove_parametrizations(self.resstack[14], "weight")
remove_parametrizations(self.resstack[17], "weight")
class MRF(nn.Module):
    """Three parallel ``ResStack`` branches whose outputs are summed.

    Args:
        kernels (Sequence[int]): kernel size for each of the three branches
            (indices 0, 1 and 2 are read).
        channel (int): channel count forwarded to every branch.
        dilations (list[int]): dilations forwarded to every branch.
    """

    # Reflection padding handed to branch 1, 2 and 3 respectively.
    _BRANCH_PADDINGS = (0, 6, 12)

    def __init__(self, kernels, channel, dilations=[1, 3, 5]):  # pylint: disable=dangerous-default-value
        super().__init__()
        for branch_idx, branch_padding in enumerate(self._BRANCH_PADDINGS):
            branch = ResStack(kernels[branch_idx], channel, branch_padding, dilations)
            setattr(self, f"resblock{branch_idx + 1}", branch)

    def forward(self, x):
        """Apply all three branches to ``x`` and return their sum."""
        return self.resblock1(x) + self.resblock2(x) + self.resblock3(x)

    def remove_weight_norm(self):
        """Remove weight normalisation from every branch."""
        for branch in (self.resblock1, self.resblock2, self.resblock3):
            branch.remove_weight_norm()
+368
View File
@@ -0,0 +1,368 @@
from typing import Dict, Union
import torch
from torch import nn
from torch.nn import functional as F
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
#################################
# GENERATOR LOSSES
#################################
class STFTLoss(nn.Module):
    """Single-resolution STFT loss (ParallelWaveGAN, https://arxiv.org/pdf/1910.11480.pdf).

    Both waveforms are passed through the same ``TorchSTFT`` transform and
    compared with an L1 distance on the log outputs and a spectral-convergence
    term.

    Args:
        n_fft, hop_length, win_length: forwarded positionally to ``TorchSTFT``.
    """

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft, self.hop_length, self.win_length = n_fft, hop_length, win_length
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        """Return ``(loss_mag, loss_sc)`` for generated ``y_hat`` and real ``y``."""
        # NOTE(review): presumably magnitude spectrograms - confirm against TorchSTFT.
        spec_fake = self.stft(y_hat)
        spec_real = self.stft(y)

        # L1 distance between the log spectrograms.
        log_real = torch.log(spec_real)
        log_fake = torch.log(spec_fake)
        loss_mag = F.l1_loss(log_real, log_fake)

        # Frobenius norm of the difference, relative to the real spectrogram.
        difference_norm = torch.norm(spec_real - spec_fake, p="fro")
        loss_sc = difference_norm / torch.norm(spec_real, p="fro")
        return loss_mag, loss_sc
class MultiScaleSTFTLoss(torch.nn.Module):
    """Average of several ``STFTLoss`` modules built at different resolutions.

    From the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf.

    Args:
        n_ffts, hop_lengths, win_lengths: parallel sequences; one ``STFTLoss``
            is created per zipped ``(n_fft, hop_length, win_length)`` triple.
    """

    def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
        super().__init__()
        resolutions = zip(n_ffts, hop_lengths, win_lengths)
        self.loss_funcs = torch.nn.ModuleList([STFTLoss(*resolution) for resolution in resolutions])

    def forward(self, y_hat, y):
        """Return ``(loss_mag, loss_sc)`` averaged over all resolutions."""
        num_resolutions = len(self.loss_funcs)
        per_resolution = [loss_fn(y_hat, y) for loss_fn in self.loss_funcs]
        loss_mag = sum(mag for mag, _ in per_resolution) / num_resolutions
        loss_sc = sum(sc for _, sc in per_resolution) / num_resolutions
        return loss_mag, loss_sc
class L1SpecLoss(nn.Module):
    """L1 loss between log spectrograms (HiFiGAN, https://arxiv.org/pdf/2010.05646.pdf).

    Args:
        sample_rate, n_fft, hop_length, win_length, mel_fmin, mel_fmax, n_mels,
        use_mel: all forwarded to ``TorchSTFT``; ``use_mel`` is also stored on
        the module.
    """

    def __init__(
        self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
    ):
        super().__init__()
        self.use_mel = use_mel
        stft_options = {
            "sample_rate": sample_rate,
            "mel_fmin": mel_fmin,
            "mel_fmax": mel_fmax,
            "n_mels": n_mels,
            "use_mel": use_mel,
        }
        self.stft = TorchSTFT(n_fft, hop_length, win_length, **stft_options)

    def forward(self, y_hat, y):
        """Return the L1 distance between the log transforms of ``y`` and ``y_hat``."""
        spec_fake = self.stft(y_hat)
        spec_real = self.stft(y)
        return F.l1_loss(torch.log(spec_real), torch.log(spec_fake))
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
    """Multi-scale STFT loss for multi-band model outputs.

    From the MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106.
    """

    # pylint: disable=no-self-use
    def forward(self, y_hat, y):
        """Fold every leading dimension into the batch and apply the parent loss.

        Both inputs are reshaped to ``(-1, 1, size_of_dim_2)`` and the singleton
        dimension is squeezed away before delegating to
        ``MultiScaleSTFTLoss.forward``.
        """
        flat_fake = y_hat.view(-1, 1, y_hat.shape[2]).squeeze(1)
        flat_real = y.view(-1, 1, y.shape[2]).squeeze(1)
        return super().forward(flat_fake, flat_real)
class MSEGLoss(nn.Module):
    """Mean squared generator loss: MSE between the given scores and all-ones."""

    # pylint: disable=no-self-use
    def forward(self, score_real):
        """Return ``mse(score_real, 1)``."""
        all_ones = score_real.new_ones(score_real.shape)
        return F.mse_loss(score_real, all_ones)
class HingeGLoss(nn.Module):
    """Hinge loss used on the generator side: ``mean(relu(1 - score))``."""

    # pylint: disable=no-self-use
    def forward(self, score_real):
        """Return the mean hinge penalty of the given discriminator scores."""
        # TODO: this might be wrong (kept from the original author; the formula
        # is left untouched so training behaviour does not change).
        margin_violation = F.relu(1.0 - score_real)
        return torch.mean(margin_violation)
##################################
# DISCRIMINATOR LOSSES
##################################
class MSEDLoss(nn.Module):
    """Mean squared discriminator loss.

    Real scores are regressed toward one and fake scores toward zero.
    """

    def __init__(self):
        super().__init__()
        self.loss_func = nn.MSELoss()

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        """Return ``(total, real_term, fake_term)`` with ``total = real + fake``."""
        real_target = score_real.new_ones(score_real.shape)
        loss_real = self.loss_func(score_real, real_target)
        fake_target = score_fake.new_zeros(score_fake.shape)
        loss_fake = self.loss_func(score_fake, fake_target)
        return loss_real + loss_fake, loss_real, loss_fake
class HingeDLoss(nn.Module):
    """Hinge discriminator loss.

    Real scores are penalised with ``relu(1 - score)`` and fake scores with
    ``relu(1 + score)``; each term is averaged over all elements.
    """

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        """Return ``(total, real_term, fake_term)`` with ``total = real + fake``."""
        loss_real = F.relu(1.0 - score_real).mean()
        loss_fake = F.relu(1.0 + score_fake).mean()
        return loss_real + loss_fake, loss_real, loss_fake
class MelganFeatureLoss(nn.Module):
    """Feature-matching loss: mean L1 distance over paired feature maps."""

    def __init__(self):
        super().__init__()
        self.loss_func = nn.L1Loss()

    # pylint: disable=no-self-use
    def forward(self, fake_feats, real_feats):
        """Average the L1 loss over every (fake, real) feature pair.

        Args:
            fake_feats: outer sequence of inner sequences of feature tensors.
            real_feats: indexable with the same outer positions as
                ``fake_feats``; inner sequences are zipped pairwise.
        """
        pair_losses = [
            self.loss_func(fake_feat, real_feat)
            for group_idx, fake_group in enumerate(fake_feats)
            for fake_feat, real_feat in zip(fake_group, real_feats[group_idx])
        ]
        return sum(pair_losses) / len(pair_losses)
#####################################
# LOSS WRAPPERS
#####################################
def _apply_G_adv_loss(scores_fake, loss_func):
    """Compute the generator adversarial loss and normalise it.

    Args:
        scores_fake: either a single score object or a ``list`` of them.
        loss_func: callable taking one score and returning its loss.

    Returns:
        ``loss_func(scores_fake)`` for a non-list input; for a list, the sum
        of the per-item losses divided by the list length.
    """
    if not isinstance(scores_fake, list):
        return loss_func(scores_fake)
    summed = sum(loss_func(score) for score in scores_fake)
    return summed / len(scores_fake)
def _apply_D_loss(scores_fake, scores_real, loss_func):
    """Compute the discriminator loss and normalise the values.

    Args:
        scores_fake: a single score object or a ``list`` of them (one per
            discriminator scale).
        scores_real: the matching real scores, same layout as ``scores_fake``.
        loss_func: callable returning ``(total, real_term, fake_term)``. It is
            called with ``score_fake=``/``score_real=`` keywords in the list
            case and positionally as ``(fake, real)`` otherwise.

    Returns:
        tuple: ``(loss, real_loss, fake_loss)``; in the list case each value is
        the sum over scales divided by the number of scales.
    """
    if not isinstance(scores_fake, list):
        # single scale loss
        total, real_part, fake_part = loss_func(scores_fake, scores_real)
        return total, real_part, fake_part

    # multi-scale loss
    total, real_part, fake_part = 0, 0, 0
    for fake_score, real_score in zip(scores_fake, scores_real):
        scale_total, scale_real, scale_fake = loss_func(score_fake=fake_score, score_real=real_score)
        total += scale_total
        real_part += scale_real
        fake_part += scale_fake

    # normalize loss values with number of scales (discriminators)
    total /= len(scores_fake)
    real_part /= len(scores_real)
    fake_part /= len(scores_fake)
    return total, real_part, fake_part
##################################
# MODEL LOSSES
##################################
class GeneratorLoss(nn.Module):
    """Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes
    losses. It allows to experiment with different combinations of loss functions with different models by just
    changing configurations.

    Every ``use_*`` flag that is missing from the configuration is treated as
    ``False`` and every missing ``*_weight`` as ``0.0``.

    Args:
        C (AttrDict): model configuration. Must support ``key in C`` and
            attribute access.
    """

    def __init__(self, C):
        super().__init__()

        def option(key, default):
            # Read an optional config entry without raising when it is absent.
            return getattr(C, key) if key in C else default

        self.use_stft_loss = option("use_stft_loss", False)
        self.use_subband_stft_loss = option("use_subband_stft_loss", False)
        self.use_mse_gan_loss = option("use_mse_gan_loss", False)
        self.use_hinge_gan_loss = option("use_hinge_gan_loss", False)
        self.use_feat_match_loss = option("use_feat_match_loss", False)
        self.use_l1_spec_loss = option("use_l1_spec_loss", False)

        assert not (
            self.use_mse_gan_loss and self.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.stft_loss_weight = option("stft_loss_weight", 0.0)
        self.subband_stft_loss_weight = option("subband_stft_loss_weight", 0.0)
        self.mse_gan_loss_weight = option("mse_G_loss_weight", 0.0)
        # The original tested for the misspelled key "hinde_G_loss_weight",
        # which silently forced this weight to 0.0.
        self.hinge_gan_loss_weight = option("hinge_G_loss_weight", 0.0)
        self.feat_match_loss_weight = option("feat_match_loss_weight", 0.0)
        self.l1_spec_loss_weight = option("l1_spec_loss_weight", 0.0)

        # Build only the loss modules that are enabled, using the same guarded
        # flags as above so a partially specified config does not raise.
        if self.use_stft_loss:
            self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
        if self.use_subband_stft_loss:
            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
        if self.use_mse_gan_loss:
            self.mse_loss = MSEGLoss()
        if self.use_hinge_gan_loss:
            self.hinge_loss = HingeGLoss()
        if self.use_feat_match_loss:
            self.feat_match_loss = MelganFeatureLoss()
        if self.use_l1_spec_loss:
            assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
            self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)

    def forward(
        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
    ):
        """Compute all enabled generator losses.

        Args:
            y_hat: generated waveform; trimmed along dim 2 to ``y.size(2)`` and
                squeezed on dim 1 for the STFT loss.
            y: reference waveform.
            scores_fake: discriminator score(s) for ``y_hat``; adversarial
                terms are skipped when ``None``.
            feats_fake, feats_real: discriminator features; the feature
                matching term is skipped when ``feats_fake`` is ``None``.
            y_hat_sub, y_sub: sub-band signals for the sub-band STFT loss.

        Returns:
            dict: individual ``G_*`` terms plus ``"G_gen_loss"``,
            ``"G_adv_loss"`` and their sum under ``"loss"``.
        """
        gen_loss = 0
        adv_loss = 0
        return_dict = {}

        # STFT Loss
        if self.use_stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
            return_dict["G_stft_loss_mg"] = stft_loss_mg
            return_dict["G_stft_loss_sc"] = stft_loss_sc
            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

        # L1 Spec loss
        if self.use_l1_spec_loss:
            l1_spec_loss = self.l1_spec_loss(y_hat, y)
            return_dict["G_l1_spec_loss"] = l1_spec_loss
            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

        # subband STFT Loss
        if self.use_subband_stft_loss:
            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

        # multiscale MSE adversarial loss
        if self.use_mse_gan_loss and scores_fake is not None:
            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
            return_dict["G_mse_fake_loss"] = mse_fake_loss
            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

        # multiscale Hinge adversarial loss
        # (the original condition was inverted: `not scores_fake is not None`)
        if self.use_hinge_gan_loss and scores_fake is not None:
            hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
            return_dict["G_hinge_fake_loss"] = hinge_fake_loss
            adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

        # Feature Matching Loss
        if self.use_feat_match_loss and feats_fake is not None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss

        return_dict["loss"] = gen_loss + adv_loss
        return_dict["G_gen_loss"] = gen_loss
        return_dict["G_adv_loss"] = adv_loss
        return return_dict
class DiscriminatorLoss(nn.Module):
    """Discriminator loss wrapper, configured like ``GeneratorLoss``.

    Args:
        C: model configuration exposing ``use_mse_gan_loss`` and
            ``use_hinge_gan_loss`` as attributes; the two must not both be set.
    """

    def __init__(self, C):
        super().__init__()
        both_enabled = C.use_mse_gan_loss and C.use_hinge_gan_loss
        assert not both_enabled, " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_mse_gan_loss = C.use_mse_gan_loss
        self.use_hinge_gan_loss = C.use_hinge_gan_loss

        if C.use_mse_gan_loss:
            self.mse_loss = MSEDLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeDLoss()

    def forward(self, scores_fake, scores_real):
        """Return a dict with the enabled ``D_*`` terms and their sum under ``"loss"``."""
        loss = 0
        return_dict = {}

        if self.use_mse_gan_loss:
            total, real_part, fake_part = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
            )
            return_dict.update(
                {
                    "D_mse_gan_loss": total,
                    "D_mse_gan_real_loss": real_part,
                    "D_mse_gan_fake_loss": fake_part,
                }
            )
            loss += total

        if self.use_hinge_gan_loss:
            total, real_part, fake_part = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
            )
            return_dict.update(
                {
                    "D_hinge_gan_loss": total,
                    "D_hinge_gan_real_loss": real_part,
                    "D_hinge_gan_fake_loss": fake_part,
                }
            )
            loss += total

        return_dict["loss"] = loss
        return return_dict
class WaveRNNLoss(nn.Module):
    """Loss wrapper that picks its criterion from the WaveRNN output mode.

    Args:
        wave_rnn_mode (Union[str, int]): ``"mold"`` selects
            ``discretized_mix_logistic_loss``, ``"gauss"`` selects
            ``gaussian_loss`` and any ``int`` selects cross entropy.

    Raises:
        ValueError: for any other mode.
    """

    def __init__(self, wave_rnn_mode: Union[str, int]):
        super().__init__()
        if wave_rnn_mode == "mold":
            criterion = discretized_mix_logistic_loss
        elif wave_rnn_mode == "gauss":
            criterion = gaussian_loss
        elif isinstance(wave_rnn_mode, int):
            criterion = torch.nn.CrossEntropyLoss()
        else:
            raise ValueError(" [!] Unknown mode for Wavernn.")
        self.loss_func = criterion

    def forward(self, y_hat, y) -> Dict:
        """Return ``{"loss": loss_func(y_hat, y)}``."""
        return {"loss": self.loss_func(y_hat, y)}
+198
View File
@@ -0,0 +1,198 @@
import torch
import torch.nn.functional as F
class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions"""

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
    ):
        """
        Args:
            cond_channels (int): number of channels of the conditioning sequence.
            conv_in_channels (int): input channels of each predicted convolution.
            conv_out_channels (int): output channels of each predicted convolution.
            conv_layers (int): number of convolutions to predict kernels and biases for.
            conv_kernel_size (int): kernel size of each predicted convolution.
            kpnet_hidden_channels (int): width of the predictor's hidden layers.
            kpnet_conv_size (int): kernel size of the predictor's own convolutions
                (the input convolution always uses 5).
            kpnet_dropout (float): dropout probability inside the residual branch.
            kpnet_nonlinear_activation (str): name of a ``torch.nn`` activation class.
            kpnet_nonlinear_activation_params (dict): keyword arguments for that class.
        """
        super().__init__()
        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        # Channel counts needed to emit every kernel weight / bias per frame.
        kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        bias_channels = conv_out_channels * conv_layers
        same_padding = (kpnet_conv_size - 1) // 2
        hidden = kpnet_hidden_channels

        def make_activation():
            return getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params)

        def make_hidden_conv():
            return torch.nn.Conv1d(hidden, hidden, kpnet_conv_size, padding=same_padding, bias=True)

        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv1d(cond_channels, hidden, 5, padding=(5 - 1) // 2, bias=True),
            make_activation(),
        )

        # Three groups of: Dropout -> (Conv1d -> activation) x 2.
        residual_layers = []
        for _group in range(3):
            residual_layers.append(torch.nn.Dropout(kpnet_dropout))
            for _pair in range(2):
                residual_layers.append(make_hidden_conv())
                residual_layers.append(make_activation())
        self.residual_conv = torch.nn.Sequential(*residual_layers)

        self.kernel_conv = torch.nn.Conv1d(hidden, kernel_channels, kpnet_conv_size, padding=same_padding, bias=True)
        self.bias_conv = torch.nn.Conv1d(hidden, bias_channels, kpnet_conv_size, padding=same_padding, bias=True)

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        Returns:
            tuple: ``kernels`` viewed as (batch, conv_layers, conv_in_channels,
            conv_out_channels, conv_kernel_size, cond_length) and ``bias`` viewed
            as (batch, conv_layers, conv_out_channels, cond_length).
        """
        batch, _, cond_length = c.shape
        hidden = self.input_conv(c)
        hidden = hidden + self.residual_conv(hidden)
        raw_kernels = self.kernel_conv(hidden)
        raw_bias = self.bias_conv(hidden)

        kernel_shape = (
            batch,
            self.conv_layers,
            self.conv_in_channels,
            self.conv_out_channels,
            self.conv_kernel_size,
            cond_length,
        )
        kernels = raw_kernels.contiguous().view(*kernel_shape)
        bias = raw_bias.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length)
        return kernels, bias
class LVCBlock(torch.nn.Module):
    """Location-variable convolution block.

    The input is upsampled by a transposed convolution and then refined by
    ``conv_layers`` gated residual steps. Each step applies a fixed dilated
    convolution followed by a location-variable convolution whose kernels and
    biases are predicted from the conditioning sequence by ``KernelPredictor``.
    """

    def __init__(
        self,
        in_channels,
        cond_channels,
        upsample_ratio,
        conv_layers=4,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()
        self.cond_hop_length = cond_hop_length
        self.conv_layers = conv_layers
        self.conv_kernel_size = conv_kernel_size
        self.convs = torch.nn.ModuleList()

        ratio_parity = upsample_ratio % 2
        self.upsample = torch.nn.ConvTranspose1d(
            in_channels,
            in_channels,
            kernel_size=upsample_ratio * 2,
            stride=upsample_ratio,
            padding=upsample_ratio // 2 + ratio_parity,
            output_padding=ratio_parity,
        )

        # The predicted convolutions emit 2 * in_channels so the result can be
        # split into a sigmoid half and a tanh half in forward().
        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=conv_layers,
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
        )

        half_kernel = int((conv_kernel_size - 1) / 2)
        for layer_idx in range(conv_layers):
            layer_dilation = 3**layer_idx
            self.convs.append(
                torch.nn.Conv1d(
                    in_channels,
                    in_channels,
                    kernel_size=conv_kernel_size,
                    padding=layer_dilation * half_kernel,
                    dilation=layer_dilation,
                )
            )

    def forward(self, x, c):
        """forward propagation of the location-variable convolutions.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        """
        in_channels = x.shape[1]
        kernels, bias = self.kernel_predictor(c)

        x = self.upsample(F.leaky_relu(x, 0.2))

        for layer_idx in range(self.conv_layers):
            hidden = self.convs[layer_idx](F.leaky_relu(x, 0.2))
            hidden = F.leaky_relu(hidden, 0.2)

            layer_kernel = kernels[:, layer_idx, :, :, :, :]
            layer_bias = bias[:, layer_idx, :, :]
            hidden = self.location_variable_convolution(hidden, layer_kernel, layer_bias, 1, self.cond_hop_length)

            # Gated residual update: first half of the channels gates the second half.
            gate = torch.sigmoid(hidden[:, :in_channels, :])
            candidate = torch.tanh(hidden[:, in_channels:, :])
            x = x + gate * candidate
        return x

    @staticmethod
    def location_variable_convolution(x, kernel, bias, dilation, hop_size):
        """Perform the location-variable convolution of ``x`` with per-frame local kernels.

        Original author's benchmark: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs,
        1000 loops each), tested on NVIDIA V100.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        """
        _, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape

        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"

        pad = dilation * int((kernel_size - 1) / 2)
        frames = F.pad(x, (pad, pad), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        # (batch, in_channels, kernel_length, hop_size + 2*padding)
        frames = frames.unfold(2, hop_size + 2 * pad, hop_size)

        if hop_size < dilation:
            frames = F.pad(frames, (0, dilation), "constant", 0)
        # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        frames = frames.unfold(3, dilation, dilation)
        frames = frames[:, :, :, :, :hop_size]
        # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        frames = frames.transpose(3, 4)
        # (batch, in_channels, kernel_length, dilation, _, kernel_size)
        frames = frames.unfold(4, kernel_size, 1)

        out = torch.einsum("bildsk,biokl->bolsd", frames, kernel)
        out = out + bias.unsqueeze(-1).unsqueeze(-1)
        return out.contiguous().view(batch, out_channels, -1)
+43
View File
@@ -0,0 +1,43 @@
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class ResidualStack(nn.Module):
def __init__(self, channels, num_res_blocks, kernel_size):
super().__init__()
assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
base_padding = (kernel_size - 1) // 2
self.blocks = nn.ModuleList()
for idx in range(num_res_blocks):
layer_kernel_size = kernel_size
layer_dilation = layer_kernel_size**idx
layer_padding = base_padding * layer_dilation
self.blocks += [
nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(layer_padding),
weight_norm(
nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
),
nn.LeakyReLU(0.2),
weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
)
]
self.shortcuts = nn.ModuleList(
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
)
def forward(self, x):
for block, shortcut in zip(self.blocks, self.shortcuts):
x = shortcut(x) + block(x)
return x
def remove_weight_norm(self):
for block, shortcut in zip(self.blocks, self.shortcuts):
remove_parametrizations(block[2], "weight")
remove_parametrizations(block[4], "weight")
remove_parametrizations(shortcut, "weight")
+77
View File
@@ -0,0 +1,77 @@
import torch
from torch.nn import functional as F
class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet.

    A dilated convolution produces ``gate_channels`` features that are split in
    half for a ``tanh * sigmoid`` gated activation, optionally biased by a 1x1
    projection of the conditioning input. Two 1x1 convolutions then produce the
    residual and the skip outputs.
    """

    def __init__(
        self,
        kernel_size=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        dilation=1,
        bias=True,
        use_causal_conv=False,
    ):
        super().__init__()
        self.dropout = dropout

        if not use_causal_conv:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            padding = (kernel_size - 1) // 2 * dilation
        else:
            # no future time stamps available: pad by the full receptive field,
            # the surplus output steps are cut off in forward()
            padding = (kernel_size - 1) * dilation
        self.use_causal_conv = use_causal_conv

        # dilation conv
        self.conv = torch.nn.Conv1d(
            res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias
        )

        # local conditioning (disabled when aux_channels <= 0)
        self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) if aux_channels > 0 else None

        # conv output is split into two groups
        half_gate_channels = gate_channels // 2
        self.conv1x1_out = torch.nn.Conv1d(half_gate_channels, res_channels, 1, bias=bias)
        self.conv1x1_skip = torch.nn.Conv1d(half_gate_channels, skip_channels, 1, bias=bias)

    def forward(self, x, c):
        """
        x: B x D_res x T
        c: B x D_aux x T (or None to skip local conditioning)

        Returns the residual output (same shape as ``x``) and the skip output.
        """
        residual = x
        hidden = F.dropout(x, p=self.dropout, training=self.training)
        hidden = self.conv(hidden)

        # remove future time steps if use_causal_conv conv
        if self.use_causal_conv:
            hidden = hidden[:, :, : residual.size(-1)]

        # split into two part for gated activation
        channel_dim = 1
        tanh_in, sigmoid_in = hidden.split(hidden.size(channel_dim) // 2, dim=channel_dim)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            cond = self.conv1x1_aux(c)
            cond_tanh, cond_sigmoid = cond.split(cond.size(channel_dim) // 2, dim=channel_dim)
            tanh_in = tanh_in + cond_tanh
            sigmoid_in = sigmoid_in + cond_sigmoid

        gated = torch.tanh(tanh_in) * torch.sigmoid(sigmoid_in)

        # for skip connection
        skip = self.conv1x1_skip(gated)

        # for residual connection
        # NOTE(review): the scale is 0.5**2 == 0.25; sqrt(0.5) may have been
        # intended - confirm before changing, trained weights depend on it.
        out = (self.conv1x1_out(gated) + residual) * (0.5**2)
        return out, skip
+53
View File
@@ -0,0 +1,53 @@
import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal as sig
# adapted from
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
class PQMF(torch.nn.Module):
    """Cosine-modulated filter bank with ``N`` sub-bands.

    A Kaiser-windowed prototype filter from ``scipy.signal.firwin`` is
    modulated into ``N`` analysis filters (buffer ``H``) and ``N`` synthesis
    filters (buffer ``G``).

    Args:
        N (int): number of sub-bands; also the analysis stride.
        taps (int): the prototype filter has ``taps + 1`` coefficients.
        cutoff (float): cutoff passed to ``firwin``.
        beta (float): Kaiser window beta passed to ``firwin``.
    """

    def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
        super().__init__()

        self.N = N
        self.taps = taps
        self.cutoff = cutoff
        self.beta = beta

        prototype = sig.firwin(taps + 1, cutoff, window=("kaiser", beta))
        num_coeffs = len(prototype)
        analysis_bank = np.zeros((N, num_coeffs))
        synthesis_bank = np.zeros((N, num_coeffs))
        for band in range(N):
            modulation = (
                (2 * band + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2))
            )  # TODO: (taps - 1) -> taps
            band_phase = (-1) ** band * np.pi / 4
            analysis_bank[band] = 2 * prototype * np.cos(modulation + band_phase)
            synthesis_bank[band] = 2 * prototype * np.cos(modulation - band_phase)

        # H: (N, 1, taps + 1) analysis filters; G: (1, N, taps + 1) synthesis filters.
        self.register_buffer("H", torch.from_numpy(analysis_bank[:, None, :]).float())
        self.register_buffer("G", torch.from_numpy(synthesis_bank[None, :, :]).float())

        # Per-band filter with a single 1.0 at position 0, used by synthesis()
        # as the transposed-convolution weight.
        updown_filter = torch.zeros((N, N, N)).float()
        for band in range(N):
            updown_filter[band, band, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.N = N

        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def forward(self, x):
        """Alias for :meth:`analysis`."""
        return self.analysis(x)

    def analysis(self, x):
        """Filter ``x`` with ``H`` using stride ``N`` and padding ``taps // 2``."""
        return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)

    def synthesis(self, x):
        """Upsample the sub-bands by ``N`` and merge them with the ``G`` filters."""
        upsampled = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N)
        return F.conv1d(upsampled, self.G, padding=self.taps // 2)
+640
View File
@@ -0,0 +1,640 @@
0.0000000e+000
-5.5252865e-004
-5.6176926e-004
-4.9475181e-004
-4.8752280e-004
-4.8937912e-004
-5.0407143e-004
-5.2265643e-004
-5.4665656e-004
-5.6778026e-004
-5.8709305e-004
-6.1327474e-004
-6.3124935e-004
-6.5403334e-004
-6.7776908e-004
-6.9416146e-004
-7.1577365e-004
-7.2550431e-004
-7.4409419e-004
-7.4905981e-004
-7.6813719e-004
-7.7248486e-004
-7.8343323e-004
-7.7798695e-004
-7.8036647e-004
-7.8014496e-004
-7.7579773e-004
-7.6307936e-004
-7.5300014e-004
-7.3193572e-004
-7.2153920e-004
-6.9179375e-004
-6.6504151e-004
-6.3415949e-004
-5.9461189e-004
-5.5645764e-004
-5.1455722e-004
-4.6063255e-004
-4.0951215e-004
-3.5011759e-004
-2.8969812e-004
-2.0983373e-004
-1.4463809e-004
-6.1733441e-005
1.3494974e-005
1.0943831e-004
2.0430171e-004
2.9495311e-004
4.0265402e-004
5.1073885e-004
6.2393761e-004
7.4580259e-004
8.6084433e-004
9.8859883e-004
1.1250155e-003
1.2577885e-003
1.3902495e-003
1.5443220e-003
1.6868083e-003
1.8348265e-003
1.9841141e-003
2.1461584e-003
2.3017255e-003
2.4625617e-003
2.6201759e-003
2.7870464e-003
2.9469448e-003
3.1125421e-003
3.2739613e-003
3.4418874e-003
3.6008268e-003
3.7603923e-003
3.9207432e-003
4.0819753e-003
4.2264269e-003
4.3730720e-003
4.5209853e-003
4.6606461e-003
4.7932561e-003
4.9137604e-003
5.0393023e-003
5.1407354e-003
5.2461166e-003
5.3471681e-003
5.4196776e-003
5.4876040e-003
5.5475715e-003
5.5938023e-003
5.6220643e-003
5.6455197e-003
5.6389200e-003
5.6266114e-003
5.5917129e-003
5.5404364e-003
5.4753783e-003
5.3838976e-003
5.2715759e-003
5.1382275e-003
4.9839688e-003
4.8109469e-003
4.6039530e-003
4.3801862e-003
4.1251642e-003
3.8456408e-003
3.5401247e-003
3.2091886e-003
2.8446758e-003
2.4508540e-003
2.0274176e-003
1.5784683e-003
1.0902329e-003
5.8322642e-004
2.7604519e-005
-5.4642809e-004
-1.1568136e-003
-1.8039473e-003
-2.4826724e-003
-3.1933778e-003
-3.9401124e-003
-4.7222596e-003
-5.5337211e-003
-6.3792293e-003
-7.2615817e-003
-8.1798233e-003
-9.1325330e-003
-1.0115022e-002
-1.1131555e-002
-1.2185000e-002
-1.3271822e-002
-1.4390467e-002
-1.5540555e-002
-1.6732471e-002
-1.7943338e-002
-1.9187243e-002
-2.0453179e-002
-2.1746755e-002
-2.3068017e-002
-2.4416099e-002
-2.5787585e-002
-2.7185943e-002
-2.8607217e-002
-3.0050266e-002
-3.1501761e-002
-3.2975408e-002
-3.4462095e-002
-3.5969756e-002
-3.7481285e-002
-3.9005368e-002
-4.0534917e-002
-4.2064909e-002
-4.3609754e-002
-4.5148841e-002
-4.6684303e-002
-4.8216572e-002
-4.9738576e-002
-5.1255616e-002
-5.2763075e-002
-5.4245277e-002
-5.5717365e-002
-5.7161645e-002
-5.8591568e-002
-5.9983748e-002
-6.1345517e-002
-6.2685781e-002
-6.3971590e-002
-6.5224711e-002
-6.6436751e-002
-6.7607599e-002
-6.8704383e-002
-6.9763024e-002
-7.0762871e-002
-7.1700267e-002
-7.2568258e-002
-7.3362026e-002
-7.4100364e-002
-7.4745256e-002
-7.5313734e-002
-7.5800836e-002
-7.6199248e-002
-7.6499217e-002
-7.6709349e-002
-7.6817398e-002
-7.6823001e-002
-7.6720492e-002
-7.6505072e-002
-7.6174832e-002
-7.5730576e-002
-7.5157626e-002
-7.4466439e-002
-7.3640601e-002
-7.2677464e-002
-7.1582636e-002
-7.0353307e-002
-6.8966401e-002
-6.7452502e-002
-6.5769067e-002
-6.3944481e-002
-6.1960278e-002
-5.9816657e-002
-5.7515269e-002
-5.5046003e-002
-5.2409382e-002
-4.9597868e-002
-4.6630331e-002
-4.3476878e-002
-4.0145828e-002
-3.6641812e-002
-3.2958393e-002
-2.9082401e-002
-2.5030756e-002
-2.0799707e-002
-1.6370126e-002
-1.1762383e-002
-6.9636862e-003
-1.9765601e-003
3.2086897e-003
8.5711749e-003
1.4128883e-002
1.9883413e-002
2.5822729e-002
3.1953127e-002
3.8277657e-002
4.4780682e-002
5.1480418e-002
5.8370533e-002
6.5440985e-002
7.2694330e-002
8.0137293e-002
8.7754754e-002
9.5553335e-002
1.0353295e-001
1.1168269e-001
1.2000780e-001
1.2850029e-001
1.3715518e-001
1.4597665e-001
1.5496071e-001
1.6409589e-001
1.7338082e-001
1.8281725e-001
1.9239667e-001
2.0212502e-001
2.1197359e-001
2.2196527e-001
2.3206909e-001
2.4230169e-001
2.5264803e-001
2.6310533e-001
2.7366340e-001
2.8432142e-001
2.9507167e-001
3.0590986e-001
3.1682789e-001
3.2781137e-001
3.3887227e-001
3.4999141e-001
3.6115899e-001
3.7237955e-001
3.8363500e-001
3.9492118e-001
4.0623177e-001
4.1756969e-001
4.2891199e-001
4.4025538e-001
4.5159965e-001
4.6293081e-001
4.7424532e-001
4.8552531e-001
4.9677083e-001
5.0798175e-001
5.1912350e-001
5.3022409e-001
5.4125534e-001
5.5220513e-001
5.6307891e-001
5.7385241e-001
5.8454032e-001
5.9511231e-001
6.0557835e-001
6.1591099e-001
6.2612427e-001
6.3619801e-001
6.4612697e-001
6.5590163e-001
6.6551399e-001
6.7496632e-001
6.8423533e-001
6.9332824e-001
7.0223887e-001
7.1094104e-001
7.1944626e-001
7.2774489e-001
7.3582118e-001
7.4368279e-001
7.5131375e-001
7.5870808e-001
7.6586749e-001
7.7277809e-001
7.7942875e-001
7.8583531e-001
7.9197358e-001
7.9784664e-001
8.0344858e-001
8.0876950e-001
8.1381913e-001
8.1857760e-001
8.2304199e-001
8.2722753e-001
8.3110385e-001
8.3469374e-001
8.3797173e-001
8.4095414e-001
8.4362383e-001
8.4598185e-001
8.4803158e-001
8.4978052e-001
8.5119715e-001
8.5230470e-001
8.5310209e-001
8.5357206e-001
8.5373856e-001
8.5357206e-001
8.5310209e-001
8.5230470e-001
8.5119715e-001
8.4978052e-001
8.4803158e-001
8.4598185e-001
8.4362383e-001
8.4095414e-001
8.3797173e-001
8.3469374e-001
8.3110385e-001
8.2722753e-001
8.2304199e-001
8.1857760e-001
8.1381913e-001
8.0876950e-001
8.0344858e-001
7.9784664e-001
7.9197358e-001
7.8583531e-001
7.7942875e-001
7.7277809e-001
7.6586749e-001
7.5870808e-001
7.5131375e-001
7.4368279e-001
7.3582118e-001
7.2774489e-001
7.1944626e-001
7.1094104e-001
7.0223887e-001
6.9332824e-001
6.8423533e-001
6.7496632e-001
6.6551399e-001
6.5590163e-001
6.4612697e-001
6.3619801e-001
6.2612427e-001
6.1591099e-001
6.0557835e-001
5.9511231e-001
5.8454032e-001
5.7385241e-001
5.6307891e-001
5.5220513e-001
5.4125534e-001
5.3022409e-001
5.1912350e-001
5.0798175e-001
4.9677083e-001
4.8552531e-001
4.7424532e-001
4.6293081e-001
4.5159965e-001
4.4025538e-001
4.2891199e-001
4.1756969e-001
4.0623177e-001
3.9492118e-001
3.8363500e-001
3.7237955e-001
3.6115899e-001
3.4999141e-001
3.3887227e-001
3.2781137e-001
3.1682789e-001
3.0590986e-001
2.9507167e-001
2.8432142e-001
2.7366340e-001
2.6310533e-001
2.5264803e-001
2.4230169e-001
2.3206909e-001
2.2196527e-001
2.1197359e-001
2.0212502e-001
1.9239667e-001
1.8281725e-001
1.7338082e-001
1.6409589e-001
1.5496071e-001
1.4597665e-001
1.3715518e-001
1.2850029e-001
1.2000780e-001
1.1168269e-001
1.0353295e-001
9.5553335e-002
8.7754754e-002
8.0137293e-002
7.2694330e-002
6.5440985e-002
5.8370533e-002
5.1480418e-002
4.4780682e-002
3.8277657e-002
3.1953127e-002
2.5822729e-002
1.9883413e-002
1.4128883e-002
8.5711749e-003
3.2086897e-003
-1.9765601e-003
-6.9636862e-003
-1.1762383e-002
-1.6370126e-002
-2.0799707e-002
-2.5030756e-002
-2.9082401e-002
-3.2958393e-002
-3.6641812e-002
-4.0145828e-002
-4.3476878e-002
-4.6630331e-002
-4.9597868e-002
-5.2409382e-002
-5.5046003e-002
-5.7515269e-002
-5.9816657e-002
-6.1960278e-002
-6.3944481e-002
-6.5769067e-002
-6.7452502e-002
-6.8966401e-002
-7.0353307e-002
-7.1582636e-002
-7.2677464e-002
-7.3640601e-002
-7.4466439e-002
-7.5157626e-002
-7.5730576e-002
-7.6174832e-002
-7.6505072e-002
-7.6720492e-002
-7.6823001e-002
-7.6817398e-002
-7.6709349e-002
-7.6499217e-002
-7.6199248e-002
-7.5800836e-002
-7.5313734e-002
-7.4745256e-002
-7.4100364e-002
-7.3362026e-002
-7.2568258e-002
-7.1700267e-002
-7.0762871e-002
-6.9763024e-002
-6.8704383e-002
-6.7607599e-002
-6.6436751e-002
-6.5224711e-002
-6.3971590e-002
-6.2685781e-002
-6.1345517e-002
-5.9983748e-002
-5.8591568e-002
-5.7161645e-002
-5.5717365e-002
-5.4245277e-002
-5.2763075e-002
-5.1255616e-002
-4.9738576e-002
-4.8216572e-002
-4.6684303e-002
-4.5148841e-002
-4.3609754e-002
-4.2064909e-002
-4.0534917e-002
-3.9005368e-002
-3.7481285e-002
-3.5969756e-002
-3.4462095e-002
-3.2975408e-002
-3.1501761e-002
-3.0050266e-002
-2.8607217e-002
-2.7185943e-002
-2.5787585e-002
-2.4416099e-002
-2.3068017e-002
-2.1746755e-002
-2.0453179e-002
-1.9187243e-002
-1.7943338e-002
-1.6732471e-002
-1.5540555e-002
-1.4390467e-002
-1.3271822e-002
-1.2185000e-002
-1.1131555e-002
-1.0115022e-002
-9.1325330e-003
-8.1798233e-003
-7.2615817e-003
-6.3792293e-003
-5.5337211e-003
-4.7222596e-003
-3.9401124e-003
-3.1933778e-003
-2.4826724e-003
-1.8039473e-003
-1.1568136e-003
-5.4642809e-004
2.7604519e-005
5.8322642e-004
1.0902329e-003
1.5784683e-003
2.0274176e-003
2.4508540e-003
2.8446758e-003
3.2091886e-003
3.5401247e-003
3.8456408e-003
4.1251642e-003
4.3801862e-003
4.6039530e-003
4.8109469e-003
4.9839688e-003
5.1382275e-003
5.2715759e-003
5.3838976e-003
5.4753783e-003
5.5404364e-003
5.5917129e-003
5.6266114e-003
5.6389200e-003
5.6455197e-003
5.6220643e-003
5.5938023e-003
5.5475715e-003
5.4876040e-003
5.4196776e-003
5.3471681e-003
5.2461166e-003
5.1407354e-003
5.0393023e-003
4.9137604e-003
4.7932561e-003
4.6606461e-003
4.5209853e-003
4.3730720e-003
4.2264269e-003
4.0819753e-003
3.9207432e-003
3.7603923e-003
3.6008268e-003
3.4418874e-003
3.2739613e-003
3.1125421e-003
2.9469448e-003
2.7870464e-003
2.6201759e-003
2.4625617e-003
2.3017255e-003
2.1461584e-003
1.9841141e-003
1.8348265e-003
1.6868083e-003
1.5443220e-003
1.3902495e-003
1.2577885e-003
1.1250155e-003
9.8859883e-004
8.6084433e-004
7.4580259e-004
6.2393761e-004
5.1073885e-004
4.0265402e-004
2.9495311e-004
2.0430171e-004
1.0943831e-004
1.3494974e-005
-6.1733441e-005
-1.4463809e-004
-2.0983373e-004
-2.8969812e-004
-3.5011759e-004
-4.0951215e-004
-4.6063255e-004
-5.1455722e-004
-5.5645764e-004
-5.9461189e-004
-6.3415949e-004
-6.6504151e-004
-6.9179375e-004
-7.2153920e-004
-7.3193572e-004
-7.5300014e-004
-7.6307936e-004
-7.7579773e-004
-7.8014496e-004
-7.8036647e-004
-7.7798695e-004
-7.8343323e-004
-7.7248486e-004
-7.6813719e-004
-7.4905981e-004
-7.4409419e-004
-7.2550431e-004
-7.1577365e-004
-6.9416146e-004
-6.7776908e-004
-6.5403334e-004
-6.3124935e-004
-6.1327474e-004
-5.8709305e-004
-5.6778026e-004
-5.4665656e-004
-5.2265643e-004
-5.0407143e-004
-4.8937912e-004
-4.8752280e-004
-4.9475181e-004
-5.6176926e-004
-5.5252865e-004
+102
View File
@@ -0,0 +1,102 @@
import torch
from torch.nn import functional as F
class Stretch2d(torch.nn.Module):
    """Resize the last two axes of a 4D tensor by fixed scale factors.

    Args:
        x_scale: Scale factor for the last axis (T in the forward docstring).
        y_scale: Scale factor for the second-to-last axis (F in the forward docstring).
        mode: Interpolation mode handed to ``F.interpolate``.
    """

    def __init__(self, x_scale, y_scale, mode="nearest"):
        super().__init__()
        self.x_scale, self.y_scale, self.mode = x_scale, y_scale, mode

    def forward(self, x):
        """
        x (Tensor): Input tensor (B, C, F, T).
        Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
        """
        # F.interpolate expects factors in axis order, i.e. (F axis, T axis).
        factors = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=factors, mode=self.mode)
class UpsampleNetwork(torch.nn.Module):
    """Upsample features along time with stacked stretch + 2D-conv layers.

    For every entry ``scale`` of ``upsample_factors`` the network appends:
    a ``Stretch2d(scale, 1)`` interpolation, a single-channel ``Conv2d`` with
    kernel ``(freq_axis_kernel_size, 2 * scale + 1)`` and no bias, and, if
    ``nonlinear_activation`` is given, the ``torch.nn`` module of that name.

    Args:
        upsample_factors: Iterable of integer time-axis scale factors.
        nonlinear_activation: Name of a ``torch.nn`` activation class, or None.
        nonlinear_activation_params: Keyword arguments for that activation
            (None is treated as no arguments).
        interpolate_mode: Interpolation mode for the stretch layers.
        freq_axis_kernel_size: Odd kernel size along the frequency axis.
        use_causal_conv: If True, pad so that each output frame only depends
            on current and past frames, and trim the surplus frames in
            ``forward``.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
    ):
        super().__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        # The default dict is only unpacked, never mutated; also accept None.
        activation_kwargs = nonlinear_activation_params or {}
        for scale in upsample_factors:
            # interpolation layer
            self.up_layers.append(Stretch2d(scale, 1, interpolate_mode))

            # conv layer
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            # Causal mode pads a full receptive field (2 * scale) on each end;
            # forward() then drops the trailing frames that look into the future.
            time_padding = scale * 2 if use_causal_conv else scale
            conv = torch.nn.Conv2d(
                1, 1, kernel_size=kernel_size, padding=(freq_axis_padding, time_padding), bias=False
            )
            self.up_layers.append(conv)

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**activation_kwargs)
                self.up_layers.append(nonlinear)

    def forward(self, c):
        """
        c :  (B, C, T_in).
        Tensor: (B, C, T_upsample)
        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for layer in self.up_layers:
            if self.use_causal_conv and isinstance(layer, torch.nn.Conv2d):
                # The symmetric 2 * scale padding makes the conv emit
                # 2 * scale extra frames; the first T outputs are exactly the
                # ones whose window ends at the current frame, so keep those.
                c = layer(c)[..., : c.size(-1)]
            else:
                c = layer(c)
        return c.squeeze(1)  # (B, C, T')
class ConvUpsample(torch.nn.Module):
    """Context-window ``Conv1d`` followed by an ``UpsampleNetwork``.

    Args:
        upsample_factors: Forwarded to ``UpsampleNetwork``.
        nonlinear_activation: Forwarded to ``UpsampleNetwork``.
        nonlinear_activation_params: Forwarded to ``UpsampleNetwork``.
        interpolate_mode: Forwarded to ``UpsampleNetwork``.
        freq_axis_kernel_size: Forwarded to ``UpsampleNetwork``.
        aux_channels: Channel count of the input and of ``conv_in``'s output.
        aux_context_window: Number of context frames; sets ``conv_in``'s kernel size.
        use_causal_conv: Selects the one-sided kernel size and is forwarded
            to ``UpsampleNetwork``.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
    ):
        super().__init__()
        self.aux_context_window = aux_context_window
        # forward() only trims frames when causal mode has context to drop.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # To capture wide-context information in conditional features
        if use_causal_conv:
            context_kernel = aux_context_window + 1
        else:
            context_kernel = 2 * aux_context_window + 1

        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=context_kernel, bias=False)

        upsample_kwargs = dict(
            upsample_factors=upsample_factors,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )
        self.upsample = UpsampleNetwork(**upsample_kwargs)

    def forward(self, c):
        """
        c : (B, C, T_in).
        Tensor: (B, C, T_upsampled),
        """
        features = self.conv_in(c)
        if self.use_causal_conv:
            # Drop the last aux_context_window frames of the conv output.
            features = features[:, :, : -self.aux_context_window]
        return self.upsample(features)
+166
View File
@@ -0,0 +1,166 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` whose weight is orthogonally initialised and whose bias is zeroed.

    Accepts exactly the arguments of ``nn.Conv1d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.orthogonal_(self.weight)
        # With ``bias=False`` the parent leaves ``self.bias`` as None, which
        # ``zeros_`` cannot handle; only reset a bias that actually exists.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
class PositionalEncoding(nn.Module):
    """Positional encoding with noise level conditioning

    The sinusoidal table is built lazily from the first input and cached in
    ``self.pe`` (shape ``(channels, length)``). It is rebuilt whenever the
    input is longer than the cache or differs from it in channel count,
    device or dtype.

    Args:
        n_channels: Stored on the module; the table itself is sized from the input.
        max_len: Stored on the module; the table itself is sized from the input.
    """

    def __init__(self, n_channels, max_len=10000):
        super().__init__()
        self.n_channels = n_channels
        self.max_len = max_len
        self.C = 5000
        self.pe = torch.zeros(0, 0)

    def forward(self, x, noise_level):
        """Return ``x + noise_level + pe / C``.

        ``x`` is indexed as (batch, channels, time); ``noise_level`` gets two
        trailing singleton axes so it broadcasts over channels and time.
        """
        n_channels, n_steps = x.shape[1], x.shape[2]
        cache_is_stale = (
            n_steps > self.pe.shape[1]
            or n_channels != self.pe.shape[0]
            or self.pe.device != x.device
            or self.pe.dtype != x.dtype
        )
        if cache_is_stale:
            self.init_pe_matrix(n_channels, n_steps, x)
        # (channels, time) broadcasts over the batch axis, so no repeat is needed.
        return x + noise_level[..., None, None] + self.pe[:, :n_steps] / self.C

    def init_pe_matrix(self, n_channels, max_len, x):
        """Build the ``(n_channels, max_len)`` table on ``x``'s device and dtype."""
        pe = torch.zeros(max_len, n_channels)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
        angles = position / div_term

        pe[:, 0::2] = torch.sin(angles)
        # For an odd channel count there is one fewer odd slot than even slots.
        pe[:, 1::2] = torch.cos(angles)[:, : n_channels // 2]
        self.pe = pe.transpose(0, 1).to(x)
class FiLM(nn.Module):
    """Produce a ``(shift, scale)`` pair from features and a noise scale.

    Args:
        input_size: Channel count of the input and of the hidden conv.
        output_size: Channel count of each returned tensor (the output conv
            emits ``2 * output_size`` channels that are split in two).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.encoding = PositionalEncoding(input_size)
        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)

        # Xavier-uniform weights and zero biases for both convolutions.
        for conv in (self.input_conv, self.output_conv):
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)

    def forward(self, x, noise_scale):
        """Return the two halves of the output conv, split along dim 1."""
        hidden = F.leaky_relu(self.input_conv(x), 0.2)
        hidden = self.encoding(hidden, noise_scale)
        shift, scale = torch.chunk(self.output_conv(hidden), 2, dim=1)
        return shift, scale

    def remove_weight_norm(self):
        """Strip the weight parametrization from both convolutions."""
        for conv in (self.input_conv, self.output_conv):
            remove_parametrizations(conv, "weight")

    def apply_weight_norm(self):
        """Wrap both convolutions with weight normalization."""
        for attr in ("input_conv", "output_conv"):
            setattr(self, attr, weight_norm(getattr(self, attr)))
@torch.jit.script
def shif_and_scale(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    """Return ``shift + scale * x``. Note the argument order: ``scale`` comes before ``shift``."""
    return shift + scale * x
class UBlock(nn.Module):
    """Upsampling residual block modulated by an external ``(shift, scale)`` pair.

    Args:
        input_size: Channel count of the input.
        hidden_size: Channel count of every convolution output.
        factor: Integer multiplier applied to the last-axis length.
        dilation: List or tuple of exactly four dilations, used in order by
            the two ``main_block`` and the two ``out_block`` convolutions.
    """

    def __init__(self, input_size, hidden_size, factor, dilation):
        super().__init__()
        assert isinstance(dilation, (list, tuple))
        assert len(dilation) == 4

        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)
        self.main_block = nn.ModuleList(
            [
                Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]),
            ]
        )
        self.out_block = nn.ModuleList(
            [
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]),
            ]
        )

    def forward(self, x, shift, scale):
        """Upsample ``x`` by ``factor`` and run the two modulated residual stages.

        ``shift`` and ``scale`` must broadcast against the hidden activations
        (presumably the outputs of a FiLM module - confirm with the caller).
        """
        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
        res = self.res_block(x_inter)

        # x_inter already has the target length, so it is used directly;
        # resizing it again to the same size would only copy the tensor.
        o = F.leaky_relu(x_inter, 0.2)
        o = self.main_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block[1](o)
        res2 = res + o

        # Second stage: modulate the residual sum, then two more convolutions.
        o = shif_and_scale(res2, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[1](o)
        o = o + res2
        return o

    def remove_weight_norm(self):
        """Strip the weight parametrization from every layer that has state."""
        remove_parametrizations(self.res_block, "weight")
        for block in (self.main_block, self.out_block):
            for layer in block:
                if len(layer.state_dict()) != 0:
                    remove_parametrizations(layer, "weight")

    def apply_weight_norm(self):
        """Wrap every layer that has state with weight normalization."""
        self.res_block = weight_norm(self.res_block)
        for block in (self.main_block, self.out_block):
            for idx, layer in enumerate(block):
                if len(layer.state_dict()) != 0:
                    block[idx] = weight_norm(layer)
class DBlock(nn.Module):
    """Downsampling residual block with three dilated convolutions.

    Args:
        input_size: Channel count of the input.
        hidden_size: Channel count of every convolution output.
        factor: Integer divisor applied to the last-axis length.
    """

    def __init__(self, input_size, hidden_size, factor):
        super().__init__()
        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)
        # Kernel-3 convolutions with dilations 1, 2, 4; padding equals the
        # dilation so the length is preserved. Only the first sees input_size.
        in_sizes = (input_size, hidden_size, hidden_size)
        self.main_block = nn.ModuleList(
            [
                Conv1d(in_size, hidden_size, 3, dilation=rate, padding=rate)
                for in_size, rate in zip(in_sizes, (1, 2, 4))
            ]
        )

    def forward(self, x):
        """Return the conv stack output plus the resized 1x1 shortcut."""
        target_len = x.shape[-1] // self.factor

        # Shortcut: 1x1 conv at full length, then resize to the target length.
        shortcut = F.interpolate(self.res_block(x), size=target_len)

        # Main path: resize first, then activation + conv for every layer.
        hidden = F.interpolate(x, size=target_len)
        for conv in self.main_block:
            hidden = conv(F.leaky_relu(hidden, 0.2))
        return hidden + shortcut

    def remove_weight_norm(self):
        """Strip the weight parametrization from every layer that has state."""
        remove_parametrizations(self.res_block, "weight")
        for layer in self.main_block:
            if len(layer.state_dict()) > 0:
                remove_parametrizations(layer, "weight")

    def apply_weight_norm(self):
        """Wrap every layer that has state with weight normalization."""
        self.res_block = weight_norm(self.res_block)
        for position, layer in enumerate(self.main_block):
            if len(layer.state_dict()) > 0:
                self.main_block[position] = weight_norm(layer)