Add files via upload
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,56 @@
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
class ResStack(nn.Module):
|
||||
def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]):
|
||||
super().__init__()
|
||||
resstack = []
|
||||
for dilation in dilations:
|
||||
resstack += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(dilation),
|
||||
nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
|
||||
),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(padding),
|
||||
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
|
||||
]
|
||||
self.resstack = nn.Sequential(*resstack)
|
||||
|
||||
self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.shortcut(x)
|
||||
x2 = self.resstack(x)
|
||||
return x1 + x2
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_parametrizations(self.shortcut, "weight")
|
||||
remove_parametrizations(self.resstack[2], "weight")
|
||||
remove_parametrizations(self.resstack[5], "weight")
|
||||
remove_parametrizations(self.resstack[8], "weight")
|
||||
remove_parametrizations(self.resstack[11], "weight")
|
||||
remove_parametrizations(self.resstack[14], "weight")
|
||||
remove_parametrizations(self.resstack[17], "weight")
|
||||
|
||||
|
||||
class MRF(nn.Module):
    """Sum of three parallel ``ResStack`` branches applied to the same input.

    Args:
        kernels (Sequence[int]): kernel sizes; entries 0, 1 and 2 are used for
            the first, second and third branch.
        channel (int): channel count passed to every branch.
        dilations (list[int]): dilations passed unchanged to every branch.
    """

    def __init__(self, kernels, channel, dilations=[1, 3, 5]):  # pylint: disable=dangerous-default-value
        super().__init__()
        # The three branches use reflection paddings 0, 6 and 12 respectively.
        for position, branch_padding in enumerate((0, 6, 12)):
            branch = ResStack(kernels[position], channel, branch_padding, dilations)
            setattr(self, f"resblock{position + 1}", branch)

    def forward(self, x):
        """Run all three branches on ``x`` and add their outputs."""
        first = self.resblock1(x)
        second = self.resblock2(x)
        third = self.resblock3(x)
        merged = first + second
        return merged + third

    def remove_weight_norm(self):
        """Remove weight normalisation from each branch in turn."""
        for branch in (self.resblock1, self.resblock2, self.resblock3):
            branch.remove_weight_norm()
|
||||
@@ -0,0 +1,368 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
|
||||
|
||||
#################################
|
||||
# GENERATOR LOSSES
|
||||
#################################
|
||||
|
||||
|
||||
class STFTLoss(nn.Module):
    """Single-resolution STFT loss.

    Generated and real waveforms are passed through the same ``TorchSTFT``
    transform; the results are compared with an L1 distance on their logs and
    with a spectral-convergence ratio.
    It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf

    Args:
        n_fft: FFT size handed to ``TorchSTFT``.
        hop_length: hop length handed to ``TorchSTFT``.
        win_length: window length handed to ``TorchSTFT``.
    """

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft, self.hop_length, self.win_length = n_fft, hop_length, win_length
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        """Return ``(loss_mag, loss_sc)`` for generated ``y_hat`` and real ``y``."""
        spec_generated = self.stft(y_hat)
        spec_real = self.stft(y)

        # magnitude loss: L1 between the log-spectra
        log_real = torch.log(spec_real)
        log_generated = torch.log(spec_generated)
        loss_mag = F.l1_loss(log_real, log_generated)

        # spectral convergence loss: norm of the error relative to the norm of the target
        error_norm = torch.norm(spec_real - spec_generated, p="fro")
        target_norm = torch.norm(spec_real, p="fro")
        return loss_mag, error_norm / target_norm
|
||||
|
||||
|
||||
class MultiScaleSTFTLoss(torch.nn.Module):
    """Average of several ``STFTLoss`` criteria with different resolutions.

    One ``STFTLoss`` is built per ``(n_fft, hop_length, win_length)`` triple
    obtained by zipping the three argument sequences.
    It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf
    """

    def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
        super().__init__()
        self.loss_funcs = torch.nn.ModuleList()
        for resolution in zip(n_ffts, hop_lengths, win_lengths):
            self.loss_funcs.append(STFTLoss(*resolution))

    def forward(self, y_hat, y):
        """Return the mean magnitude loss and mean spectral-convergence loss."""
        per_scale = [criterion(y_hat, y) for criterion in self.loss_funcs]
        num_scales = len(self.loss_funcs)
        loss_mag = sum(mag for mag, _ in per_scale)
        loss_sc = sum(sc for _, sc in per_scale)
        loss_sc /= num_scales
        loss_mag /= num_scales
        return loss_mag, loss_sc
|
||||
|
||||
|
||||
class L1SpecLoss(nn.Module):
    """L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf

    All constructor arguments are forwarded to ``TorchSTFT``; ``use_mel`` is
    additionally stored on the instance.
    """

    def __init__(
        self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
    ):
        super().__init__()
        self.use_mel = use_mel
        transform_options = {
            "sample_rate": sample_rate,
            "mel_fmin": mel_fmin,
            "mel_fmax": mel_fmax,
            "n_mels": n_mels,
            "use_mel": use_mel,
        }
        self.stft = TorchSTFT(n_fft, hop_length, win_length, **transform_options)

    def forward(self, y_hat, y):
        """Return the L1 distance between the log-spectrograms of ``y`` and ``y_hat``."""
        spec_generated = self.stft(y_hat)
        spec_real = self.stft(y)
        # magnitude loss
        return F.l1_loss(torch.log(spec_real), torch.log(spec_generated))
|
||||
|
||||
|
||||
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
    """Multiscale STFT loss for multi band model outputs.

    The first two dimensions of each input are merged so that every sub-band
    is scored by the parent loss as an independent signal.
    From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""

    # pylint: disable=no-self-use
    def forward(self, y_hat, y):
        """Flatten the leading dimensions of both inputs and defer to the parent loss."""
        flat_generated = y_hat.view(-1, 1, y_hat.shape[2]).squeeze(1)
        flat_real = y.view(-1, 1, y.shape[2]).squeeze(1)
        return super().forward(flat_generated, flat_real)
|
||||
|
||||
|
||||
class MSEGLoss(nn.Module):
    """Mean Squared Generator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_real):
        """Return the MSE between ``score_real`` and a tensor of ones of the same shape."""
        all_ones = torch.ones_like(score_real)
        return F.mse_loss(score_real, all_ones)
|
||||
|
||||
|
||||
class HingeGLoss(nn.Module):
    """Hinge loss applied to discriminator scores on the generator side."""

    # pylint: disable=no-self-use
    def forward(self, score_real):
        """Return ``mean(relu(1 - score_real))``."""
        # TODO: this might be wrong
        margin_violation = F.relu(1.0 - score_real)
        return torch.mean(margin_violation)
|
||||
|
||||
|
||||
##################################
|
||||
# DISCRIMINATOR LOSSES
|
||||
##################################
|
||||
|
||||
|
||||
class MSEDLoss(nn.Module):
    """Mean Squared Discriminator Loss"""

    def __init__(self):
        super().__init__()
        self.loss_func = nn.MSELoss()

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        """Score real inputs against ones and fake inputs against zeros.

        Returns:
            tuple: ``(loss_real + loss_fake, loss_real, loss_fake)``.
        """
        real_target = torch.ones_like(score_real)
        fake_target = torch.zeros_like(score_fake)
        loss_real = self.loss_func(score_real, real_target)
        loss_fake = self.loss_func(score_fake, fake_target)
        return loss_real + loss_fake, loss_real, loss_fake
|
||||
|
||||
|
||||
class HingeDLoss(nn.Module):
    """Hinge Discriminator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        """Compute the hinge terms for real and fake scores.

        Returns:
            tuple: ``(loss_real + loss_fake, loss_real, loss_fake)`` where
            ``loss_real = mean(relu(1 - score_real))`` and
            ``loss_fake = mean(relu(1 + score_fake))``.
        """
        real_violation = F.relu(1.0 - score_real)
        fake_violation = F.relu(1.0 + score_fake)
        loss_real = torch.mean(real_violation)
        loss_fake = torch.mean(fake_violation)
        return loss_real + loss_fake, loss_real, loss_fake
|
||||
|
||||
|
||||
class MelganFeatureLoss(nn.Module):
    """Mean L1 distance between paired fake and real feature maps."""

    def __init__(self):
        super().__init__()
        self.loss_func = nn.L1Loss()

    # pylint: disable=no-self-use
    def forward(self, fake_feats, real_feats):
        """Average ``L1Loss`` over every paired feature map.

        ``fake_feats[i]`` is zipped with ``real_feats[i]`` for every ``i``, and
        the summed losses are divided by the number of pairs visited.
        """
        paired_feats = [
            (fake_feat, real_feat)
            for idx, fake_group in enumerate(fake_feats)
            for fake_feat, real_feat in zip(fake_group, real_feats[idx])
        ]
        total = sum(self.loss_func(fake_feat, real_feat) for fake_feat, real_feat in paired_feats)
        return total / len(paired_feats)
|
||||
|
||||
|
||||
#####################################
|
||||
# LOSS WRAPPERS
|
||||
#####################################
|
||||
|
||||
|
||||
def _apply_G_adv_loss(scores_fake, loss_func):
|
||||
"""Compute G adversarial loss function
|
||||
and normalize values"""
|
||||
adv_loss = 0
|
||||
if isinstance(scores_fake, list):
|
||||
for score_fake in scores_fake:
|
||||
fake_loss = loss_func(score_fake)
|
||||
adv_loss += fake_loss
|
||||
adv_loss /= len(scores_fake)
|
||||
else:
|
||||
fake_loss = loss_func(scores_fake)
|
||||
adv_loss = fake_loss
|
||||
return adv_loss
|
||||
|
||||
|
||||
def _apply_D_loss(scores_fake, scores_real, loss_func):
|
||||
"""Compute D loss func and normalize loss values"""
|
||||
loss = 0
|
||||
real_loss = 0
|
||||
fake_loss = 0
|
||||
if isinstance(scores_fake, list):
|
||||
# multi-scale loss
|
||||
for score_fake, score_real in zip(scores_fake, scores_real):
|
||||
total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real)
|
||||
loss += total_loss
|
||||
real_loss += real_loss_
|
||||
fake_loss += fake_loss_
|
||||
# normalize loss values with number of scales (discriminators)
|
||||
loss /= len(scores_fake)
|
||||
real_loss /= len(scores_real)
|
||||
fake_loss /= len(scores_fake)
|
||||
else:
|
||||
# single scale loss
|
||||
total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
|
||||
loss = total_loss
|
||||
return loss, real_loss, fake_loss
|
||||
|
||||
|
||||
##################################
|
||||
# MODEL LOSSES
|
||||
##################################
|
||||
|
||||
|
||||
class GeneratorLoss(nn.Module):
    """Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes
    losses. It allows to experiment with different combinations of loss functions with different models by just
    changing configurations.

    Every ``use_*`` flag defaults to ``False`` and every ``*_loss_weight``
    defaults to ``0.0`` when the key is absent from the configuration.

    Args:
        C (AttrDict): model configuration.
    """

    def __init__(self, C):
        super().__init__()

        self.use_stft_loss = self._config_value(C, "use_stft_loss", False)
        self.use_subband_stft_loss = self._config_value(C, "use_subband_stft_loss", False)
        self.use_mse_gan_loss = self._config_value(C, "use_mse_gan_loss", False)
        self.use_hinge_gan_loss = self._config_value(C, "use_hinge_gan_loss", False)
        self.use_feat_match_loss = self._config_value(C, "use_feat_match_loss", False)
        self.use_l1_spec_loss = self._config_value(C, "use_l1_spec_loss", False)

        assert not (
            self.use_mse_gan_loss and self.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.stft_loss_weight = self._config_value(C, "stft_loss_weight", 0.0)
        self.subband_stft_loss_weight = self._config_value(C, "subband_stft_loss_weight", 0.0)
        self.mse_gan_loss_weight = self._config_value(C, "mse_G_loss_weight", 0.0)
        # The key was previously misspelled ("hinde_G_loss_weight"), which
        # silently forced this weight to 0.0.
        self.hinge_gan_loss_weight = self._config_value(C, "hinge_G_loss_weight", 0.0)
        self.feat_match_loss_weight = self._config_value(C, "feat_match_loss_weight", 0.0)
        self.l1_spec_loss_weight = self._config_value(C, "l1_spec_loss_weight", 0.0)

        # Build only the criteria that are enabled, using the defaulted flags
        # so a configuration with missing keys does not fail here.
        if self.use_stft_loss:
            self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
        if self.use_subband_stft_loss:
            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
        if self.use_mse_gan_loss:
            self.mse_loss = MSEGLoss()
        if self.use_hinge_gan_loss:
            self.hinge_loss = HingeGLoss()
        if self.use_feat_match_loss:
            self.feat_match_loss = MelganFeatureLoss()
        if self.use_l1_spec_loss:
            assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
            self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)

    @staticmethod
    def _config_value(C, key, default):
        """Return ``C.<key>`` when ``key`` is present in ``C``, else ``default``."""
        return getattr(C, key) if key in C else default

    def forward(
        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
    ):
        """Compute every enabled generator loss term.

        Returns:
            dict: the individual terms under ``G_*`` keys, plus ``G_gen_loss``
            (weighted spectral terms), ``G_adv_loss`` (weighted adversarial and
            feature-matching terms) and ``loss`` (their sum).
        """
        gen_loss = 0
        adv_loss = 0
        return_dict = {}

        # STFT Loss
        if self.use_stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
            return_dict["G_stft_loss_mg"] = stft_loss_mg
            return_dict["G_stft_loss_sc"] = stft_loss_sc
            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

        # L1 Spec loss
        if self.use_l1_spec_loss:
            l1_spec_loss = self.l1_spec_loss(y_hat, y)
            return_dict["G_l1_spec_loss"] = l1_spec_loss
            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

        # subband STFT Loss
        if self.use_subband_stft_loss:
            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

        # multiscale MSE adversarial loss
        if self.use_mse_gan_loss and scores_fake is not None:
            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
            return_dict["G_mse_fake_loss"] = mse_fake_loss
            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

        # multiscale Hinge adversarial loss
        # (the condition used to be inverted and only fired when scores_fake was None)
        if self.use_hinge_gan_loss and scores_fake is not None:
            hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
            return_dict["G_hinge_fake_loss"] = hinge_fake_loss
            adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

        # Feature Matching Loss
        if self.use_feat_match_loss and feats_fake is not None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss

        return_dict["loss"] = gen_loss + adv_loss
        return_dict["G_gen_loss"] = gen_loss
        return_dict["G_adv_loss"] = adv_loss
        return return_dict
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
    """Discriminator counterpart of ```GeneratorLoss```.

    Builds the MSE or hinge discriminator criterion selected by the
    configuration; enabling both at once is rejected.

    Args:
        C (AttrDict): model configuration providing ``use_mse_gan_loss`` and
            ``use_hinge_gan_loss``.
    """

    def __init__(self, C):
        super().__init__()
        assert not (
            C.use_mse_gan_loss and C.use_hinge_gan_loss
        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_mse_gan_loss = C.use_mse_gan_loss
        self.use_hinge_gan_loss = C.use_hinge_gan_loss

        if C.use_mse_gan_loss:
            self.mse_loss = MSEDLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeDLoss()

    def forward(self, scores_fake, scores_real):
        """Compute the enabled discriminator loss.

        Returns:
            dict: total, real and fake terms of the active criterion under
            ``D_*`` keys, plus the accumulated value under ``loss``.
        """
        return_dict = {}
        loss = 0

        if self.use_mse_gan_loss:
            total, real_part, fake_part = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
            )
            return_dict.update(
                {
                    "D_mse_gan_loss": total,
                    "D_mse_gan_real_loss": real_part,
                    "D_mse_gan_fake_loss": fake_part,
                }
            )
            loss = loss + total

        if self.use_hinge_gan_loss:
            total, real_part, fake_part = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
            )
            return_dict.update(
                {
                    "D_hinge_gan_loss": total,
                    "D_hinge_gan_real_loss": real_part,
                    "D_hinge_gan_fake_loss": fake_part,
                }
            )
            loss = loss + total

        return_dict["loss"] = loss
        return return_dict
|
||||
|
||||
|
||||
class WaveRNNLoss(nn.Module):
    """Loss wrapper whose criterion is chosen by the WaveRNN output mode.

    Args:
        wave_rnn_mode (Union[str, int]): ``"mold"`` selects
            ``discretized_mix_logistic_loss``, ``"gauss"`` selects
            ``gaussian_loss`` and any integer selects ``CrossEntropyLoss``.

    Raises:
        ValueError: if the mode is none of the above.
    """

    def __init__(self, wave_rnn_mode: Union[str, int]):
        super().__init__()
        if wave_rnn_mode == "mold":
            criterion = discretized_mix_logistic_loss
        elif wave_rnn_mode == "gauss":
            criterion = gaussian_loss
        elif isinstance(wave_rnn_mode, int):
            criterion = torch.nn.CrossEntropyLoss()
        else:
            raise ValueError(" [!] Unknown mode for Wavernn.")
        self.loss_func = criterion

    def forward(self, y_hat, y) -> Dict:
        """Return ``{"loss": loss_func(y_hat, y)}``."""
        return {"loss": self.loss_func(y_hat, y)}
|
||||
@@ -0,0 +1,198 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions"""

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
    ):
        """
        Args:
            cond_channels (int): number of channel for the conditioning sequence,
            conv_in_channels (int): number of channel for the input sequence,
            conv_out_channels (int): number of channel for the output sequence,
            conv_layers (int): number of convolution layers to predict kernels and biases for,
            conv_kernel_size (int): kernel size of each predicted convolution,
            kpnet_hidden_channels (int): channel width of the predictor network,
            kpnet_conv_size (int): kernel size of the predictor's hidden and output convolutions,
            kpnet_dropout (float): dropout probability used in the residual branch,
            kpnet_nonlinear_activation (str): name of the ``torch.nn`` activation class,
            kpnet_nonlinear_activation_params (dict): keyword arguments for that activation.
        """
        super().__init__()

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        # Output widths: one value per predicted kernel weight / bias entry.
        kernel_width = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        bias_width = conv_out_channels * conv_layers

        same_padding = (kpnet_conv_size - 1) // 2
        activation_cls = getattr(torch.nn, kpnet_nonlinear_activation)

        def new_activation():
            return activation_cls(**kpnet_nonlinear_activation_params)

        def new_conv(out_channels):
            return torch.nn.Conv1d(
                kpnet_hidden_channels, out_channels, kpnet_conv_size, padding=same_padding, bias=True
            )

        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
            new_activation(),
        )

        # Three groups of: dropout, then two (conv, activation) pairs.
        residual_layers = []
        for _ in range(3):
            residual_layers.append(torch.nn.Dropout(kpnet_dropout))
            for _ in range(2):
                residual_layers.append(new_conv(kpnet_hidden_channels))
                residual_layers.append(new_activation())
        self.residual_conv = torch.nn.Sequential(*residual_layers)

        self.kernel_conv = new_conv(kernel_width)
        self.bias_conv = new_conv(bias_width)

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        Returns:
            tuple: ``kernels`` viewed as (batch, conv_layers, conv_in_channels,
            conv_out_channels, conv_kernel_size, cond_length) and ``bias``
            viewed as (batch, conv_layers, conv_out_channels, cond_length).
        """
        batch, _, cond_length = c.shape

        hidden = self.input_conv(c)
        hidden = hidden + self.residual_conv(hidden)
        flat_kernels = self.kernel_conv(hidden)
        flat_bias = self.bias_conv(hidden)

        kernel_shape = (
            batch,
            self.conv_layers,
            self.conv_in_channels,
            self.conv_out_channels,
            self.conv_kernel_size,
            cond_length,
        )
        bias_shape = (batch, self.conv_layers, self.conv_out_channels, cond_length)
        kernels = flat_kernels.contiguous().view(*kernel_shape)
        bias = flat_bias.contiguous().view(*bias_shape)
        return kernels, bias
|
||||
|
||||
|
||||
class LVCBlock(torch.nn.Module):
    """the location-variable convolutions"""

    def __init__(
        self,
        in_channels,
        cond_channels,
        upsample_ratio,
        conv_layers=4,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = conv_layers
        self.conv_kernel_size = conv_kernel_size
        self.convs = torch.nn.ModuleList()

        half_ratio, odd_ratio = divmod(upsample_ratio, 2)
        self.upsample = torch.nn.ConvTranspose1d(
            in_channels,
            in_channels,
            kernel_size=upsample_ratio * 2,
            stride=upsample_ratio,
            padding=half_ratio + odd_ratio,
            output_padding=odd_ratio,
        )

        # Predicts 2 * in_channels outputs per layer: one half feeds the
        # sigmoid gate and the other the tanh branch in ``forward``.
        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=conv_layers,
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
        )

        # Dilation grows as 3**layer; padding keeps the sequence length unchanged.
        half_kernel = int((conv_kernel_size - 1) / 2)
        for layer_idx in range(conv_layers):
            layer_dilation = 3**layer_idx
            self.convs.append(
                torch.nn.Conv1d(
                    in_channels,
                    in_channels,
                    kernel_size=conv_kernel_size,
                    padding=layer_dilation * half_kernel,
                    dilation=layer_dilation,
                )
            )

    def forward(self, x, c):
        """forward propagation of the location-variable convolutions.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the upsampled sequence after all gated residual updates
        """
        num_channels = x.shape[1]
        kernels, bias = self.kernel_predictor(c)

        x = self.upsample(F.leaky_relu(x, 0.2))

        for layer_idx in range(self.conv_layers):
            hidden = F.leaky_relu(x, 0.2)
            hidden = self.convs[layer_idx](hidden)
            hidden = F.leaky_relu(hidden, 0.2)

            layer_kernel = kernels[:, layer_idx, :, :, :, :]
            layer_bias = bias[:, layer_idx, :, :]
            hidden = self.location_variable_convolution(hidden, layer_kernel, layer_bias, 1, self.cond_hop_length)

            # Gated activation: first half of the channels gates the second half.
            gate = torch.sigmoid(hidden[:, :num_channels, :])
            signal = torch.tanh(hidden[:, num_channels:, :])
            x = x + gate * signal
        return x

    @staticmethod
    def location_variable_convolution(x, kernel, bias, dilation, hop_size):
        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernel.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        """
        _, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape

        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"

        padding = dilation * int((kernel_size - 1) / 2)
        # (batch, in_channels, in_length + 2*padding)
        frames = F.pad(x, (padding, padding), "constant", 0)
        # (batch, in_channels, kernel_length, hop_size + 2*padding)
        frames = frames.unfold(2, hop_size + 2 * padding, hop_size)

        if hop_size < dilation:
            frames = F.pad(frames, (0, dilation), "constant", 0)
        # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        frames = frames.unfold(3, dilation, dilation)
        frames = frames[:, :, :, :, :hop_size]
        # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        frames = frames.transpose(3, 4)
        # (batch, in_channels, kernel_length, dilation, _, kernel_size)
        frames = frames.unfold(4, kernel_size, 1)

        out = torch.einsum("bildsk,biokl->bolsd", frames, kernel)
        out = out + bias.unsqueeze(-1).unsqueeze(-1)
        return out.contiguous().view(batch, out_channels, -1)
|
||||
@@ -0,0 +1,43 @@
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
|
||||
class ResidualStack(nn.Module):
|
||||
def __init__(self, channels, num_res_blocks, kernel_size):
|
||||
super().__init__()
|
||||
|
||||
assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
|
||||
base_padding = (kernel_size - 1) // 2
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for idx in range(num_res_blocks):
|
||||
layer_kernel_size = kernel_size
|
||||
layer_dilation = layer_kernel_size**idx
|
||||
layer_padding = base_padding * layer_dilation
|
||||
self.blocks += [
|
||||
nn.Sequential(
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(layer_padding),
|
||||
weight_norm(
|
||||
nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
|
||||
),
|
||||
nn.LeakyReLU(0.2),
|
||||
weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
|
||||
)
|
||||
]
|
||||
|
||||
self.shortcuts = nn.ModuleList(
|
||||
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||
x = shortcut(x) + block(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||
remove_parametrizations(block[2], "weight")
|
||||
remove_parametrizations(block[4], "weight")
|
||||
remove_parametrizations(shortcut, "weight")
|
||||
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(
        self,
        kernel_size=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        dilation=1,
        bias=True,
        use_causal_conv=False,
    ):
        super().__init__()
        self.dropout = dropout
        self.use_causal_conv = use_causal_conv

        if use_causal_conv:
            # no future time stamps available: pad by the full receptive field
            # (the surplus is cut off again in ``forward``)
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            padding = (kernel_size - 1) // 2 * dilation

        # dilation conv
        self.conv = torch.nn.Conv1d(
            res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias
        )

        # local conditioning (disabled when aux_channels is not positive)
        self.conv1x1_aux = (
            torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) if aux_channels > 0 else None
        )

        # conv output is split into two groups, so the projections see half the gate channels
        half_gate = gate_channels // 2
        self.conv1x1_out = torch.nn.Conv1d(half_gate, res_channels, 1, bias=bias)
        self.conv1x1_skip = torch.nn.Conv1d(half_gate, skip_channels, 1, bias=bias)

    def forward(self, x, c):
        """
        x: B x D_res x T
        c: B x D_aux x T

        Returns the pair ``(residual output, skip output)``.
        """
        channel_dim = 1
        block_input = x

        hidden = F.dropout(x, p=self.dropout, training=self.training)
        hidden = self.conv(hidden)

        # remove future time steps if use_causal_conv conv
        if self.use_causal_conv:
            hidden = hidden[:, :, : block_input.size(-1)]

        # split into two part for gated activation
        tanh_in, sigmoid_in = hidden.split(hidden.size(channel_dim) // 2, dim=channel_dim)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            cond = self.conv1x1_aux(c)
            cond_tanh, cond_sigmoid = cond.split(cond.size(channel_dim) // 2, dim=channel_dim)
            tanh_in = tanh_in + cond_tanh
            sigmoid_in = sigmoid_in + cond_sigmoid

        gated = torch.tanh(tanh_in) * torch.sigmoid(sigmoid_in)

        # for skip connection
        skip = self.conv1x1_skip(gated)

        # for residual connection
        # NOTE(review): the sum is scaled by 0.5**2 == 0.25; confirm this factor is intended.
        residual_out = (self.conv1x1_out(gated) + block_input) * (0.5**2)

        return residual_out, skip
|
||||
@@ -0,0 +1,53 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy import signal as sig
|
||||
|
||||
|
||||
# adapted from
|
||||
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
|
||||
class PQMF(torch.nn.Module):
    """Cosine-modulated filter bank for splitting a signal into ``N`` bands and back.

    A Kaiser-windowed FIR prototype with ``taps + 1`` coefficients is modulated
    into ``N`` analysis filters (buffer ``H``) and ``N`` synthesis filters
    (buffer ``G``).

    Args:
        N (int): number of sub-bands; also the analysis stride.
        taps (int): filter order; every filter has ``taps + 1`` coefficients.
        cutoff (float): cutoff passed to ``scipy.signal.firwin``.
        beta (float): beta parameter of the Kaiser window.
    """

    def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
        super().__init__()

        self.N, self.taps, self.cutoff, self.beta = N, taps, cutoff, beta

        prototype = sig.firwin(taps + 1, cutoff, window=("kaiser", beta))
        num_coeffs = len(prototype)
        analysis_bank = np.zeros((N, num_coeffs))
        synthesis_bank = np.zeros((N, num_coeffs))

        # TODO: (taps - 1) -> taps
        centred_index = np.arange(taps + 1) - ((taps - 1) / 2)
        for band in range(N):
            modulation = (2 * band + 1) * (np.pi / (2 * N)) * centred_index
            phase = (-1) ** band * np.pi / 4
            analysis_bank[band] = 2 * prototype * np.cos(modulation + phase)
            synthesis_bank[band] = 2 * prototype * np.cos(modulation - phase)

        # H is laid out as (N, 1, taps + 1) and G as (1, N, taps + 1).
        self.register_buffer("H", torch.from_numpy(analysis_bank[:, None, :]).float())
        self.register_buffer("G", torch.from_numpy(synthesis_bank[None, :, :]).float())

        # Transposed-conv weight with a single 1.0 per band at [k, k, 0].
        updown_filter = torch.zeros((N, N, N)).float()
        band_index = torch.arange(N)
        updown_filter[band_index, band_index, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)

        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def forward(self, x):
        """Alias for ``analysis``."""
        return self.analysis(x)

    def analysis(self, x):
        """Filter ``x`` with the analysis bank using stride ``N``."""
        half_taps = self.taps // 2
        return F.conv1d(x, self.H, padding=half_taps, stride=self.N)

    def synthesis(self, x):
        """Upsample the sub-bands by ``N`` and filter them with the synthesis bank."""
        upsampled = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N)
        return F.conv1d(upsampled, self.G, padding=self.taps // 2)
|
||||
@@ -0,0 +1,640 @@
|
||||
0.0000000e+000
|
||||
-5.5252865e-004
|
||||
-5.6176926e-004
|
||||
-4.9475181e-004
|
||||
-4.8752280e-004
|
||||
-4.8937912e-004
|
||||
-5.0407143e-004
|
||||
-5.2265643e-004
|
||||
-5.4665656e-004
|
||||
-5.6778026e-004
|
||||
-5.8709305e-004
|
||||
-6.1327474e-004
|
||||
-6.3124935e-004
|
||||
-6.5403334e-004
|
||||
-6.7776908e-004
|
||||
-6.9416146e-004
|
||||
-7.1577365e-004
|
||||
-7.2550431e-004
|
||||
-7.4409419e-004
|
||||
-7.4905981e-004
|
||||
-7.6813719e-004
|
||||
-7.7248486e-004
|
||||
-7.8343323e-004
|
||||
-7.7798695e-004
|
||||
-7.8036647e-004
|
||||
-7.8014496e-004
|
||||
-7.7579773e-004
|
||||
-7.6307936e-004
|
||||
-7.5300014e-004
|
||||
-7.3193572e-004
|
||||
-7.2153920e-004
|
||||
-6.9179375e-004
|
||||
-6.6504151e-004
|
||||
-6.3415949e-004
|
||||
-5.9461189e-004
|
||||
-5.5645764e-004
|
||||
-5.1455722e-004
|
||||
-4.6063255e-004
|
||||
-4.0951215e-004
|
||||
-3.5011759e-004
|
||||
-2.8969812e-004
|
||||
-2.0983373e-004
|
||||
-1.4463809e-004
|
||||
-6.1733441e-005
|
||||
1.3494974e-005
|
||||
1.0943831e-004
|
||||
2.0430171e-004
|
||||
2.9495311e-004
|
||||
4.0265402e-004
|
||||
5.1073885e-004
|
||||
6.2393761e-004
|
||||
7.4580259e-004
|
||||
8.6084433e-004
|
||||
9.8859883e-004
|
||||
1.1250155e-003
|
||||
1.2577885e-003
|
||||
1.3902495e-003
|
||||
1.5443220e-003
|
||||
1.6868083e-003
|
||||
1.8348265e-003
|
||||
1.9841141e-003
|
||||
2.1461584e-003
|
||||
2.3017255e-003
|
||||
2.4625617e-003
|
||||
2.6201759e-003
|
||||
2.7870464e-003
|
||||
2.9469448e-003
|
||||
3.1125421e-003
|
||||
3.2739613e-003
|
||||
3.4418874e-003
|
||||
3.6008268e-003
|
||||
3.7603923e-003
|
||||
3.9207432e-003
|
||||
4.0819753e-003
|
||||
4.2264269e-003
|
||||
4.3730720e-003
|
||||
4.5209853e-003
|
||||
4.6606461e-003
|
||||
4.7932561e-003
|
||||
4.9137604e-003
|
||||
5.0393023e-003
|
||||
5.1407354e-003
|
||||
5.2461166e-003
|
||||
5.3471681e-003
|
||||
5.4196776e-003
|
||||
5.4876040e-003
|
||||
5.5475715e-003
|
||||
5.5938023e-003
|
||||
5.6220643e-003
|
||||
5.6455197e-003
|
||||
5.6389200e-003
|
||||
5.6266114e-003
|
||||
5.5917129e-003
|
||||
5.5404364e-003
|
||||
5.4753783e-003
|
||||
5.3838976e-003
|
||||
5.2715759e-003
|
||||
5.1382275e-003
|
||||
4.9839688e-003
|
||||
4.8109469e-003
|
||||
4.6039530e-003
|
||||
4.3801862e-003
|
||||
4.1251642e-003
|
||||
3.8456408e-003
|
||||
3.5401247e-003
|
||||
3.2091886e-003
|
||||
2.8446758e-003
|
||||
2.4508540e-003
|
||||
2.0274176e-003
|
||||
1.5784683e-003
|
||||
1.0902329e-003
|
||||
5.8322642e-004
|
||||
2.7604519e-005
|
||||
-5.4642809e-004
|
||||
-1.1568136e-003
|
||||
-1.8039473e-003
|
||||
-2.4826724e-003
|
||||
-3.1933778e-003
|
||||
-3.9401124e-003
|
||||
-4.7222596e-003
|
||||
-5.5337211e-003
|
||||
-6.3792293e-003
|
||||
-7.2615817e-003
|
||||
-8.1798233e-003
|
||||
-9.1325330e-003
|
||||
-1.0115022e-002
|
||||
-1.1131555e-002
|
||||
-1.2185000e-002
|
||||
-1.3271822e-002
|
||||
-1.4390467e-002
|
||||
-1.5540555e-002
|
||||
-1.6732471e-002
|
||||
-1.7943338e-002
|
||||
-1.9187243e-002
|
||||
-2.0453179e-002
|
||||
-2.1746755e-002
|
||||
-2.3068017e-002
|
||||
-2.4416099e-002
|
||||
-2.5787585e-002
|
||||
-2.7185943e-002
|
||||
-2.8607217e-002
|
||||
-3.0050266e-002
|
||||
-3.1501761e-002
|
||||
-3.2975408e-002
|
||||
-3.4462095e-002
|
||||
-3.5969756e-002
|
||||
-3.7481285e-002
|
||||
-3.9005368e-002
|
||||
-4.0534917e-002
|
||||
-4.2064909e-002
|
||||
-4.3609754e-002
|
||||
-4.5148841e-002
|
||||
-4.6684303e-002
|
||||
-4.8216572e-002
|
||||
-4.9738576e-002
|
||||
-5.1255616e-002
|
||||
-5.2763075e-002
|
||||
-5.4245277e-002
|
||||
-5.5717365e-002
|
||||
-5.7161645e-002
|
||||
-5.8591568e-002
|
||||
-5.9983748e-002
|
||||
-6.1345517e-002
|
||||
-6.2685781e-002
|
||||
-6.3971590e-002
|
||||
-6.5224711e-002
|
||||
-6.6436751e-002
|
||||
-6.7607599e-002
|
||||
-6.8704383e-002
|
||||
-6.9763024e-002
|
||||
-7.0762871e-002
|
||||
-7.1700267e-002
|
||||
-7.2568258e-002
|
||||
-7.3362026e-002
|
||||
-7.4100364e-002
|
||||
-7.4745256e-002
|
||||
-7.5313734e-002
|
||||
-7.5800836e-002
|
||||
-7.6199248e-002
|
||||
-7.6499217e-002
|
||||
-7.6709349e-002
|
||||
-7.6817398e-002
|
||||
-7.6823001e-002
|
||||
-7.6720492e-002
|
||||
-7.6505072e-002
|
||||
-7.6174832e-002
|
||||
-7.5730576e-002
|
||||
-7.5157626e-002
|
||||
-7.4466439e-002
|
||||
-7.3640601e-002
|
||||
-7.2677464e-002
|
||||
-7.1582636e-002
|
||||
-7.0353307e-002
|
||||
-6.8966401e-002
|
||||
-6.7452502e-002
|
||||
-6.5769067e-002
|
||||
-6.3944481e-002
|
||||
-6.1960278e-002
|
||||
-5.9816657e-002
|
||||
-5.7515269e-002
|
||||
-5.5046003e-002
|
||||
-5.2409382e-002
|
||||
-4.9597868e-002
|
||||
-4.6630331e-002
|
||||
-4.3476878e-002
|
||||
-4.0145828e-002
|
||||
-3.6641812e-002
|
||||
-3.2958393e-002
|
||||
-2.9082401e-002
|
||||
-2.5030756e-002
|
||||
-2.0799707e-002
|
||||
-1.6370126e-002
|
||||
-1.1762383e-002
|
||||
-6.9636862e-003
|
||||
-1.9765601e-003
|
||||
3.2086897e-003
|
||||
8.5711749e-003
|
||||
1.4128883e-002
|
||||
1.9883413e-002
|
||||
2.5822729e-002
|
||||
3.1953127e-002
|
||||
3.8277657e-002
|
||||
4.4780682e-002
|
||||
5.1480418e-002
|
||||
5.8370533e-002
|
||||
6.5440985e-002
|
||||
7.2694330e-002
|
||||
8.0137293e-002
|
||||
8.7754754e-002
|
||||
9.5553335e-002
|
||||
1.0353295e-001
|
||||
1.1168269e-001
|
||||
1.2000780e-001
|
||||
1.2850029e-001
|
||||
1.3715518e-001
|
||||
1.4597665e-001
|
||||
1.5496071e-001
|
||||
1.6409589e-001
|
||||
1.7338082e-001
|
||||
1.8281725e-001
|
||||
1.9239667e-001
|
||||
2.0212502e-001
|
||||
2.1197359e-001
|
||||
2.2196527e-001
|
||||
2.3206909e-001
|
||||
2.4230169e-001
|
||||
2.5264803e-001
|
||||
2.6310533e-001
|
||||
2.7366340e-001
|
||||
2.8432142e-001
|
||||
2.9507167e-001
|
||||
3.0590986e-001
|
||||
3.1682789e-001
|
||||
3.2781137e-001
|
||||
3.3887227e-001
|
||||
3.4999141e-001
|
||||
3.6115899e-001
|
||||
3.7237955e-001
|
||||
3.8363500e-001
|
||||
3.9492118e-001
|
||||
4.0623177e-001
|
||||
4.1756969e-001
|
||||
4.2891199e-001
|
||||
4.4025538e-001
|
||||
4.5159965e-001
|
||||
4.6293081e-001
|
||||
4.7424532e-001
|
||||
4.8552531e-001
|
||||
4.9677083e-001
|
||||
5.0798175e-001
|
||||
5.1912350e-001
|
||||
5.3022409e-001
|
||||
5.4125534e-001
|
||||
5.5220513e-001
|
||||
5.6307891e-001
|
||||
5.7385241e-001
|
||||
5.8454032e-001
|
||||
5.9511231e-001
|
||||
6.0557835e-001
|
||||
6.1591099e-001
|
||||
6.2612427e-001
|
||||
6.3619801e-001
|
||||
6.4612697e-001
|
||||
6.5590163e-001
|
||||
6.6551399e-001
|
||||
6.7496632e-001
|
||||
6.8423533e-001
|
||||
6.9332824e-001
|
||||
7.0223887e-001
|
||||
7.1094104e-001
|
||||
7.1944626e-001
|
||||
7.2774489e-001
|
||||
7.3582118e-001
|
||||
7.4368279e-001
|
||||
7.5131375e-001
|
||||
7.5870808e-001
|
||||
7.6586749e-001
|
||||
7.7277809e-001
|
||||
7.7942875e-001
|
||||
7.8583531e-001
|
||||
7.9197358e-001
|
||||
7.9784664e-001
|
||||
8.0344858e-001
|
||||
8.0876950e-001
|
||||
8.1381913e-001
|
||||
8.1857760e-001
|
||||
8.2304199e-001
|
||||
8.2722753e-001
|
||||
8.3110385e-001
|
||||
8.3469374e-001
|
||||
8.3797173e-001
|
||||
8.4095414e-001
|
||||
8.4362383e-001
|
||||
8.4598185e-001
|
||||
8.4803158e-001
|
||||
8.4978052e-001
|
||||
8.5119715e-001
|
||||
8.5230470e-001
|
||||
8.5310209e-001
|
||||
8.5357206e-001
|
||||
8.5373856e-001
|
||||
8.5357206e-001
|
||||
8.5310209e-001
|
||||
8.5230470e-001
|
||||
8.5119715e-001
|
||||
8.4978052e-001
|
||||
8.4803158e-001
|
||||
8.4598185e-001
|
||||
8.4362383e-001
|
||||
8.4095414e-001
|
||||
8.3797173e-001
|
||||
8.3469374e-001
|
||||
8.3110385e-001
|
||||
8.2722753e-001
|
||||
8.2304199e-001
|
||||
8.1857760e-001
|
||||
8.1381913e-001
|
||||
8.0876950e-001
|
||||
8.0344858e-001
|
||||
7.9784664e-001
|
||||
7.9197358e-001
|
||||
7.8583531e-001
|
||||
7.7942875e-001
|
||||
7.7277809e-001
|
||||
7.6586749e-001
|
||||
7.5870808e-001
|
||||
7.5131375e-001
|
||||
7.4368279e-001
|
||||
7.3582118e-001
|
||||
7.2774489e-001
|
||||
7.1944626e-001
|
||||
7.1094104e-001
|
||||
7.0223887e-001
|
||||
6.9332824e-001
|
||||
6.8423533e-001
|
||||
6.7496632e-001
|
||||
6.6551399e-001
|
||||
6.5590163e-001
|
||||
6.4612697e-001
|
||||
6.3619801e-001
|
||||
6.2612427e-001
|
||||
6.1591099e-001
|
||||
6.0557835e-001
|
||||
5.9511231e-001
|
||||
5.8454032e-001
|
||||
5.7385241e-001
|
||||
5.6307891e-001
|
||||
5.5220513e-001
|
||||
5.4125534e-001
|
||||
5.3022409e-001
|
||||
5.1912350e-001
|
||||
5.0798175e-001
|
||||
4.9677083e-001
|
||||
4.8552531e-001
|
||||
4.7424532e-001
|
||||
4.6293081e-001
|
||||
4.5159965e-001
|
||||
4.4025538e-001
|
||||
4.2891199e-001
|
||||
4.1756969e-001
|
||||
4.0623177e-001
|
||||
3.9492118e-001
|
||||
3.8363500e-001
|
||||
3.7237955e-001
|
||||
3.6115899e-001
|
||||
3.4999141e-001
|
||||
3.3887227e-001
|
||||
3.2781137e-001
|
||||
3.1682789e-001
|
||||
3.0590986e-001
|
||||
2.9507167e-001
|
||||
2.8432142e-001
|
||||
2.7366340e-001
|
||||
2.6310533e-001
|
||||
2.5264803e-001
|
||||
2.4230169e-001
|
||||
2.3206909e-001
|
||||
2.2196527e-001
|
||||
2.1197359e-001
|
||||
2.0212502e-001
|
||||
1.9239667e-001
|
||||
1.8281725e-001
|
||||
1.7338082e-001
|
||||
1.6409589e-001
|
||||
1.5496071e-001
|
||||
1.4597665e-001
|
||||
1.3715518e-001
|
||||
1.2850029e-001
|
||||
1.2000780e-001
|
||||
1.1168269e-001
|
||||
1.0353295e-001
|
||||
9.5553335e-002
|
||||
8.7754754e-002
|
||||
8.0137293e-002
|
||||
7.2694330e-002
|
||||
6.5440985e-002
|
||||
5.8370533e-002
|
||||
5.1480418e-002
|
||||
4.4780682e-002
|
||||
3.8277657e-002
|
||||
3.1953127e-002
|
||||
2.5822729e-002
|
||||
1.9883413e-002
|
||||
1.4128883e-002
|
||||
8.5711749e-003
|
||||
3.2086897e-003
|
||||
-1.9765601e-003
|
||||
-6.9636862e-003
|
||||
-1.1762383e-002
|
||||
-1.6370126e-002
|
||||
-2.0799707e-002
|
||||
-2.5030756e-002
|
||||
-2.9082401e-002
|
||||
-3.2958393e-002
|
||||
-3.6641812e-002
|
||||
-4.0145828e-002
|
||||
-4.3476878e-002
|
||||
-4.6630331e-002
|
||||
-4.9597868e-002
|
||||
-5.2409382e-002
|
||||
-5.5046003e-002
|
||||
-5.7515269e-002
|
||||
-5.9816657e-002
|
||||
-6.1960278e-002
|
||||
-6.3944481e-002
|
||||
-6.5769067e-002
|
||||
-6.7452502e-002
|
||||
-6.8966401e-002
|
||||
-7.0353307e-002
|
||||
-7.1582636e-002
|
||||
-7.2677464e-002
|
||||
-7.3640601e-002
|
||||
-7.4466439e-002
|
||||
-7.5157626e-002
|
||||
-7.5730576e-002
|
||||
-7.6174832e-002
|
||||
-7.6505072e-002
|
||||
-7.6720492e-002
|
||||
-7.6823001e-002
|
||||
-7.6817398e-002
|
||||
-7.6709349e-002
|
||||
-7.6499217e-002
|
||||
-7.6199248e-002
|
||||
-7.5800836e-002
|
||||
-7.5313734e-002
|
||||
-7.4745256e-002
|
||||
-7.4100364e-002
|
||||
-7.3362026e-002
|
||||
-7.2568258e-002
|
||||
-7.1700267e-002
|
||||
-7.0762871e-002
|
||||
-6.9763024e-002
|
||||
-6.8704383e-002
|
||||
-6.7607599e-002
|
||||
-6.6436751e-002
|
||||
-6.5224711e-002
|
||||
-6.3971590e-002
|
||||
-6.2685781e-002
|
||||
-6.1345517e-002
|
||||
-5.9983748e-002
|
||||
-5.8591568e-002
|
||||
-5.7161645e-002
|
||||
-5.5717365e-002
|
||||
-5.4245277e-002
|
||||
-5.2763075e-002
|
||||
-5.1255616e-002
|
||||
-4.9738576e-002
|
||||
-4.8216572e-002
|
||||
-4.6684303e-002
|
||||
-4.5148841e-002
|
||||
-4.3609754e-002
|
||||
-4.2064909e-002
|
||||
-4.0534917e-002
|
||||
-3.9005368e-002
|
||||
-3.7481285e-002
|
||||
-3.5969756e-002
|
||||
-3.4462095e-002
|
||||
-3.2975408e-002
|
||||
-3.1501761e-002
|
||||
-3.0050266e-002
|
||||
-2.8607217e-002
|
||||
-2.7185943e-002
|
||||
-2.5787585e-002
|
||||
-2.4416099e-002
|
||||
-2.3068017e-002
|
||||
-2.1746755e-002
|
||||
-2.0453179e-002
|
||||
-1.9187243e-002
|
||||
-1.7943338e-002
|
||||
-1.6732471e-002
|
||||
-1.5540555e-002
|
||||
-1.4390467e-002
|
||||
-1.3271822e-002
|
||||
-1.2185000e-002
|
||||
-1.1131555e-002
|
||||
-1.0115022e-002
|
||||
-9.1325330e-003
|
||||
-8.1798233e-003
|
||||
-7.2615817e-003
|
||||
-6.3792293e-003
|
||||
-5.5337211e-003
|
||||
-4.7222596e-003
|
||||
-3.9401124e-003
|
||||
-3.1933778e-003
|
||||
-2.4826724e-003
|
||||
-1.8039473e-003
|
||||
-1.1568136e-003
|
||||
-5.4642809e-004
|
||||
2.7604519e-005
|
||||
5.8322642e-004
|
||||
1.0902329e-003
|
||||
1.5784683e-003
|
||||
2.0274176e-003
|
||||
2.4508540e-003
|
||||
2.8446758e-003
|
||||
3.2091886e-003
|
||||
3.5401247e-003
|
||||
3.8456408e-003
|
||||
4.1251642e-003
|
||||
4.3801862e-003
|
||||
4.6039530e-003
|
||||
4.8109469e-003
|
||||
4.9839688e-003
|
||||
5.1382275e-003
|
||||
5.2715759e-003
|
||||
5.3838976e-003
|
||||
5.4753783e-003
|
||||
5.5404364e-003
|
||||
5.5917129e-003
|
||||
5.6266114e-003
|
||||
5.6389200e-003
|
||||
5.6455197e-003
|
||||
5.6220643e-003
|
||||
5.5938023e-003
|
||||
5.5475715e-003
|
||||
5.4876040e-003
|
||||
5.4196776e-003
|
||||
5.3471681e-003
|
||||
5.2461166e-003
|
||||
5.1407354e-003
|
||||
5.0393023e-003
|
||||
4.9137604e-003
|
||||
4.7932561e-003
|
||||
4.6606461e-003
|
||||
4.5209853e-003
|
||||
4.3730720e-003
|
||||
4.2264269e-003
|
||||
4.0819753e-003
|
||||
3.9207432e-003
|
||||
3.7603923e-003
|
||||
3.6008268e-003
|
||||
3.4418874e-003
|
||||
3.2739613e-003
|
||||
3.1125421e-003
|
||||
2.9469448e-003
|
||||
2.7870464e-003
|
||||
2.6201759e-003
|
||||
2.4625617e-003
|
||||
2.3017255e-003
|
||||
2.1461584e-003
|
||||
1.9841141e-003
|
||||
1.8348265e-003
|
||||
1.6868083e-003
|
||||
1.5443220e-003
|
||||
1.3902495e-003
|
||||
1.2577885e-003
|
||||
1.1250155e-003
|
||||
9.8859883e-004
|
||||
8.6084433e-004
|
||||
7.4580259e-004
|
||||
6.2393761e-004
|
||||
5.1073885e-004
|
||||
4.0265402e-004
|
||||
2.9495311e-004
|
||||
2.0430171e-004
|
||||
1.0943831e-004
|
||||
1.3494974e-005
|
||||
-6.1733441e-005
|
||||
-1.4463809e-004
|
||||
-2.0983373e-004
|
||||
-2.8969812e-004
|
||||
-3.5011759e-004
|
||||
-4.0951215e-004
|
||||
-4.6063255e-004
|
||||
-5.1455722e-004
|
||||
-5.5645764e-004
|
||||
-5.9461189e-004
|
||||
-6.3415949e-004
|
||||
-6.6504151e-004
|
||||
-6.9179375e-004
|
||||
-7.2153920e-004
|
||||
-7.3193572e-004
|
||||
-7.5300014e-004
|
||||
-7.6307936e-004
|
||||
-7.7579773e-004
|
||||
-7.8014496e-004
|
||||
-7.8036647e-004
|
||||
-7.7798695e-004
|
||||
-7.8343323e-004
|
||||
-7.7248486e-004
|
||||
-7.6813719e-004
|
||||
-7.4905981e-004
|
||||
-7.4409419e-004
|
||||
-7.2550431e-004
|
||||
-7.1577365e-004
|
||||
-6.9416146e-004
|
||||
-6.7776908e-004
|
||||
-6.5403334e-004
|
||||
-6.3124935e-004
|
||||
-6.1327474e-004
|
||||
-5.8709305e-004
|
||||
-5.6778026e-004
|
||||
-5.4665656e-004
|
||||
-5.2265643e-004
|
||||
-5.0407143e-004
|
||||
-4.8937912e-004
|
||||
-4.8752280e-004
|
||||
-4.9475181e-004
|
||||
-5.6176926e-004
|
||||
-5.5252865e-004
|
||||
@@ -0,0 +1,102 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Stretch2d(torch.nn.Module):
    """Resize the last two axes of a 4-D tensor by fixed scale factors.

    Args:
        x_scale: scale factor for the last axis (T).
        y_scale: scale factor for the second-to-last axis (F).
        mode: interpolation mode forwarded to ``F.interpolate``.
    """

    def __init__(self, x_scale, y_scale, mode="nearest"):
        super().__init__()
        self.x_scale, self.y_scale, self.mode = x_scale, y_scale, mode

    def forward(self, x):
        """
        x (Tensor): Input tensor (B, C, F, T).
        Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
        """
        # F.interpolate takes the factors in axis order: (F axis, T axis).
        scale_factor = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=scale_factor, mode=self.mode)
|
||||
|
||||
|
||||
class UpsampleNetwork(torch.nn.Module):
    """Stack of (stretch -> 2-D conv -> optional activation) upsampling stages.

    One stage is created per entry of ``upsample_factors``; each stage
    stretches the time axis by its factor and smooths it with a
    single-channel, bias-free ``Conv2d``.

    Args:
        upsample_factors: iterable of integer time-axis scale factors.
        nonlinear_activation: name of a ``torch.nn`` activation class, or
            ``None`` for no activation.
        nonlinear_activation_params: keyword arguments for that activation.
            The default dict is only read, never mutated.
        interpolate_mode: interpolation mode for the stretch layers.
        freq_axis_kernel_size: odd kernel size along the frequency axis.
        use_causal_conv: if ``True``, each conv is padded by ``2 * scale`` on
            the time axis and its output is trimmed back to the input length
            in ``forward``.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
    ):
        super().__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_factors:
            # interpolation layer
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # conv layer
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                # Symmetric padding of the full receptive field; the surplus
                # `2 * scale` frames at the end are cut off in `forward`.
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers += [nonlinear]

    def forward(self, c):
        """
        c :  (B, C, T_in).
        Tensor: (B, C, T_upsample)
        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for layer in self.up_layers:
            if self.use_causal_conv and isinstance(layer, torch.nn.Conv2d):
                # With padding `2 * scale` and kernel `2 * scale + 1` the conv
                # returns `2 * scale` extra frames; keep only the first
                # `c.size(-1)` so the length matches the non-causal path.
                c = layer(c)[..., : c.size(-1)]
            else:
                c = layer(c)
        return c.squeeze(1)  # (B, C, T')
|
||||
|
||||
|
||||
class ConvUpsample(torch.nn.Module):
    """Context-capturing 1-D convolution followed by an ``UpsampleNetwork``.

    Args:
        upsample_factors: forwarded to ``UpsampleNetwork``.
        nonlinear_activation: forwarded to ``UpsampleNetwork``.
        nonlinear_activation_params: forwarded to ``UpsampleNetwork``.
        interpolate_mode: forwarded to ``UpsampleNetwork``.
        freq_axis_kernel_size: forwarded to ``UpsampleNetwork``.
        aux_channels: channel count of the conditioning features.
        aux_context_window: number of context frames seen by the input conv.
        use_causal_conv: selects the causal kernel size and is forwarded
            unchanged to ``UpsampleNetwork``.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
    ):
        super().__init__()
        self.aux_context_window = aux_context_window
        # Output trimming in `forward` is only enabled when there is a
        # non-empty context window to cut off.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # To capture wide-context information in conditional features
        if use_causal_conv:
            kernel_size = aux_context_window + 1
        else:
            kernel_size = 2 * aux_context_window + 1

        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)

        upsample_kwargs = {
            "upsample_factors": upsample_factors,
            "nonlinear_activation": nonlinear_activation,
            "nonlinear_activation_params": nonlinear_activation_params,
            "interpolate_mode": interpolate_mode,
            "freq_axis_kernel_size": freq_axis_kernel_size,
            "use_causal_conv": use_causal_conv,
        }
        self.upsample = UpsampleNetwork(**upsample_kwargs)

    def forward(self, c):
        """
        c :  (B, C, T_in).
        Tensor: (B, C, T_upsampled),
        """
        features = self.conv_in(c)
        if self.use_causal_conv:
            # Drop the trailing `aux_context_window` frames.
            features = features[:, :, : -self.aux_context_window]
        return self.upsample(features)
|
||||
@@ -0,0 +1,166 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with orthogonal weight initialisation and zero bias.

    Accepts exactly the same arguments as ``nn.Conv1d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.orthogonal_(self.weight)
        # `bias` is None when the layer is built with `bias=False`; there is
        # nothing to zero in that case.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
    """Positional encoding with noise level conditioning

    The sinusoidal table is built lazily on the first call and cached in
    ``self.pe`` with shape ``(channels, length)``. It is rebuilt whenever the
    cached table does not fit the incoming tensor.
    """

    def __init__(self, n_channels, max_len=10000):
        super().__init__()
        self.n_channels = n_channels
        self.max_len = max_len
        self.C = 5000  # divisor applied to the table before it is added to x
        self.pe = torch.zeros(0, 0)

    def forward(self, x, noise_level):
        """Add the noise level and the scaled positional table to ``x``.

        Args:
            x: tensor indexed as (batch, channels, time).
            noise_level: tensor broadcast over the last two axes of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        n_channels, length = x.shape[1], x.shape[2]
        # Rebuild the cache if it is too short, was built for another channel
        # count, or lives on a different device / dtype than `x`.
        stale = (
            self.pe.shape[0] != n_channels
            or self.pe.shape[1] < length
            or self.pe.device != x.device
            or self.pe.dtype != x.dtype
        )
        if stale:
            self.init_pe_matrix(n_channels, length, x)
        # The (channels, time) table broadcasts over the batch axis.
        return x + noise_level[..., None, None] + self.pe[:, :length] / self.C

    def init_pe_matrix(self, n_channels, max_len, x):
        """Build the ``(n_channels, max_len)`` table, matching ``x``'s device and dtype."""
        pe = torch.zeros(max_len, n_channels)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
        angles = position / div_term

        pe[:, 0::2] = torch.sin(angles)
        # For an odd channel count there is one fewer cosine column than sine
        # columns, so drop the surplus frequency.
        pe[:, 1::2] = torch.cos(angles[:, : n_channels // 2])
        self.pe = pe.transpose(0, 1).to(x)
|
||||
|
||||
|
||||
class FiLM(nn.Module):
    """Produce a (shift, scale) pair from features and a noise scale.

    Args:
        input_size: channel count of the input features.
        output_size: channel count of each of the two returned tensors.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.encoding = PositionalEncoding(input_size)
        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
        # Twice `output_size` channels so the result can be split in two.
        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)

        for conv in (self.input_conv, self.output_conv):
            nn.init.xavier_uniform_(conv.weight)
        for conv in (self.input_conv, self.output_conv):
            nn.init.zeros_(conv.bias)

    def forward(self, x, noise_scale):
        """Return ``(shift, scale)``, the two channel halves of the output conv."""
        hidden = F.leaky_relu(self.input_conv(x), 0.2)
        hidden = self.encoding(hidden, noise_scale)
        projected = self.output_conv(hidden)
        shift, scale = torch.chunk(projected, 2, dim=1)
        return shift, scale

    def remove_weight_norm(self):
        """Remove the ``weight`` parametrization from both convolutions."""
        for conv in (self.input_conv, self.output_conv):
            remove_parametrizations(conv, "weight")

    def apply_weight_norm(self):
        """Wrap both convolutions with weight normalisation."""
        self.input_conv = weight_norm(self.input_conv)
        self.output_conv = weight_norm(self.output_conv)
|
||||
|
||||
|
||||
@torch.jit.script
def shif_and_scale(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    """Return ``x`` multiplied by ``scale`` and offset by ``shift``."""
    return shift + scale * x
|
||||
|
||||
|
||||
class UBlock(nn.Module):
    """Upsampling residual block modulated by a (shift, scale) pair.

    Args:
        input_size: channel count of the input.
        hidden_size: channel count of the output.
        factor: integer upsampling factor applied to the last axis.
        dilation: list or tuple of exactly four dilation rates, one per
            3-tap convolution.
    """

    def __init__(self, input_size, hidden_size, factor, dilation):
        super().__init__()
        assert isinstance(dilation, (list, tuple))
        assert len(dilation) == 4

        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)
        # padding == dilation keeps the length unchanged for a 3-tap kernel.
        self.main_block = nn.ModuleList(
            [
                Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]),
            ]
        )
        self.out_block = nn.ModuleList(
            [
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]),
            ]
        )

    def forward(self, x, shift, scale):
        """Upsample ``x`` by ``self.factor`` and apply two modulated residual stages.

        ``shift`` and ``scale`` must broadcast against the upsampled hidden
        features (presumably they come from a FiLM module; confirm at the
        call site).
        """
        # Upsample once; both the shortcut and the main path start from it.
        # (The original resized the activated tensor a second time to the size
        # it already had, which was a no-op copy.)
        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
        res = self.res_block(x_inter)

        o = F.leaky_relu(x_inter, 0.2)
        o = self.main_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block[1](o)
        res2 = res + o

        o = shif_and_scale(res2, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[1](o)
        return o + res2

    def remove_weight_norm(self):
        """Remove the ``weight`` parametrization from every convolution."""
        remove_parametrizations(self.res_block, "weight")
        for block in (self.main_block, self.out_block):
            for layer in block:
                # Skip layers that hold no state.
                if len(layer.state_dict()) != 0:
                    remove_parametrizations(layer, "weight")

    def apply_weight_norm(self):
        """Wrap every convolution with weight normalisation."""
        self.res_block = weight_norm(self.res_block)
        for block in (self.main_block, self.out_block):
            for idx, layer in enumerate(block):
                if len(layer.state_dict()) != 0:
                    block[idx] = weight_norm(layer)
|
||||
|
||||
|
||||
class DBlock(nn.Module):
    """Downsampling residual block with three dilated convolutions.

    Args:
        input_size: channel count of the input.
        hidden_size: channel count of the output.
        factor: integer by which the last axis is shortened.
    """

    def __init__(self, input_size, hidden_size, factor):
        super().__init__()
        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)
        # Dilations 1, 2, 4; only the first layer changes the channel count.
        self.main_block = nn.ModuleList(
            [
                Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
                Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
                Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
            ]
        )

    def forward(self, x):
        """Return the downsampled main path plus the downsampled shortcut."""
        target_len = x.shape[-1] // self.factor

        # Shortcut: 1x1 conv at full resolution, then resize.
        shortcut = F.interpolate(self.res_block(x), size=target_len)

        # Main path: resize first, then activation + conv per layer.
        hidden = F.interpolate(x, size=target_len)
        for conv in self.main_block:
            hidden = conv(F.leaky_relu(hidden, 0.2))
        return hidden + shortcut

    def remove_weight_norm(self):
        """Remove the ``weight`` parametrization from every convolution."""
        remove_parametrizations(self.res_block, "weight")
        for conv in self.main_block:
            # Skip layers that hold no state.
            if len(conv.state_dict()) != 0:
                remove_parametrizations(conv, "weight")

    def apply_weight_norm(self):
        """Wrap every convolution with weight normalisation."""
        self.res_block = weight_norm(self.res_block)
        for position, conv in enumerate(self.main_block):
            if len(conv.state_dict()) != 0:
                self.main_block[position] = weight_norm(conv)
|
||||
Reference in New Issue
Block a user