Add files via upload
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,154 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions.normal import Normal
|
||||
|
||||
|
||||
def gaussian_loss(y_hat, y, log_std_min=-7.0):
|
||||
assert y_hat.dim() == 3
|
||||
assert y_hat.size(2) == 2
|
||||
mean = y_hat[:, :, :1]
|
||||
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
|
||||
# TODO: replace with pytorch dist
|
||||
log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
|
||||
return log_probs.squeeze().mean()
|
||||
|
||||
|
||||
def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
|
||||
assert y_hat.size(2) == 2
|
||||
mean = y_hat[:, :, :1]
|
||||
log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
|
||||
dist = Normal(
|
||||
mean,
|
||||
torch.exp(log_std),
|
||||
)
|
||||
sample = dist.sample()
|
||||
sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
|
||||
del dist
|
||||
return sample
|
||||
|
||||
|
||||
def log_sum_exp(x):
|
||||
"""numerically stable log_sum_exp implementation that prevents overflow"""
|
||||
# TF ordering
|
||||
axis = len(x.size()) - 1
|
||||
m, _ = torch.max(x, dim=axis)
|
||||
m2, _ = torch.max(x, dim=axis, keepdim=True)
|
||||
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
|
||||
|
||||
|
||||
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
|
||||
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True):
|
||||
if log_scale_min is None:
|
||||
log_scale_min = float(np.log(1e-14))
|
||||
y_hat = y_hat.permute(0, 2, 1)
|
||||
assert y_hat.dim() == 3
|
||||
assert y_hat.size(1) % 3 == 0
|
||||
nr_mix = y_hat.size(1) // 3
|
||||
|
||||
# (B x T x C)
|
||||
y_hat = y_hat.transpose(1, 2)
|
||||
|
||||
# unpack parameters. (B, T, num_mixtures) x 3
|
||||
logit_probs = y_hat[:, :, :nr_mix]
|
||||
means = y_hat[:, :, nr_mix : 2 * nr_mix]
|
||||
log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)
|
||||
|
||||
# B x T x 1 -> B x T x num_mixtures
|
||||
y = y.expand_as(means)
|
||||
|
||||
centered_y = y - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
|
||||
cdf_plus = torch.sigmoid(plus_in)
|
||||
min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
|
||||
cdf_min = torch.sigmoid(min_in)
|
||||
|
||||
# log probability for edge case of 0 (before scaling)
|
||||
# equivalent: torch.log(F.sigmoid(plus_in))
|
||||
log_cdf_plus = plus_in - F.softplus(plus_in)
|
||||
|
||||
# log probability for edge case of 255 (before scaling)
|
||||
# equivalent: (1 - F.sigmoid(min_in)).log()
|
||||
log_one_minus_cdf_min = -F.softplus(min_in)
|
||||
|
||||
# probability for all other cases
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
mid_in = inv_stdv * centered_y
|
||||
# log probability in the center of the bin, to be used in extreme cases
|
||||
# (not actually used in our code)
|
||||
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
|
||||
|
||||
# tf equivalent
|
||||
|
||||
# log_probs = tf.where(x < -0.999, log_cdf_plus,
|
||||
# tf.where(x > 0.999, log_one_minus_cdf_min,
|
||||
# tf.where(cdf_delta > 1e-5,
|
||||
# tf.log(tf.maximum(cdf_delta, 1e-12)),
|
||||
# log_pdf_mid - np.log(127.5))))
|
||||
|
||||
# TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
|
||||
# for num_classes=65536 case? 1e-7? not sure..
|
||||
inner_inner_cond = (cdf_delta > 1e-5).float()
|
||||
|
||||
inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
|
||||
log_pdf_mid - np.log((num_classes - 1) / 2)
|
||||
)
|
||||
inner_cond = (y > 0.999).float()
|
||||
inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
|
||||
cond = (y < -0.999).float()
|
||||
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
|
||||
|
||||
log_probs = log_probs + F.log_softmax(logit_probs, -1)
|
||||
|
||||
if reduce:
|
||||
return -torch.mean(log_sum_exp(log_probs))
|
||||
return -log_sum_exp(log_probs).unsqueeze(-1)
|
||||
|
||||
|
||||
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
|
||||
"""
|
||||
Sample from discretized mixture of logistic distributions
|
||||
Args:
|
||||
y (Tensor): :math:`[B, C, T]`
|
||||
log_scale_min (float): Log scale minimum value
|
||||
Returns:
|
||||
Tensor: sample in range of [-1, 1].
|
||||
"""
|
||||
if log_scale_min is None:
|
||||
log_scale_min = float(np.log(1e-14))
|
||||
assert y.size(1) % 3 == 0
|
||||
nr_mix = y.size(1) // 3
|
||||
|
||||
# B x T x C
|
||||
y = y.transpose(1, 2)
|
||||
logit_probs = y[:, :, :nr_mix]
|
||||
|
||||
# sample mixture indicator from softmax
|
||||
temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
|
||||
temp = logit_probs.data - torch.log(-torch.log(temp))
|
||||
_, argmax = temp.max(dim=-1)
|
||||
|
||||
# (B, T) -> (B, T, nr_mix)
|
||||
one_hot = to_one_hot(argmax, nr_mix)
|
||||
# select logistic parameters
|
||||
means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
|
||||
log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
|
||||
# sample from logistic & clip to interval
|
||||
# we don't actually round to the nearest 8bit value when sampling
|
||||
u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
|
||||
x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))
|
||||
|
||||
x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def to_one_hot(tensor, n, fill_with=1.0):
|
||||
# we perform one hot encore with respect to the last axis
|
||||
one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor)
|
||||
one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
|
||||
return one_hot
|
||||
@@ -0,0 +1,72 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
def interpolate_vocoder_input(scale_factor, spec):
|
||||
"""Interpolate spectrogram by the scale factor.
|
||||
It is mainly used to match the sampling rates of
|
||||
the tts and vocoder models.
|
||||
|
||||
Args:
|
||||
scale_factor (float): scale factor to interpolate the spectrogram
|
||||
spec (np.array): spectrogram to be interpolated
|
||||
|
||||
Returns:
|
||||
torch.tensor: interpolated spectrogram.
|
||||
"""
|
||||
print(" > before interpolation :", spec.shape)
|
||||
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
|
||||
spec = torch.nn.functional.interpolate(
|
||||
spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
print(" > after interpolation :", spec.shape)
|
||||
return spec
|
||||
|
||||
|
||||
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
|
||||
"""Plot the predicted and the real waveform and their spectrograms.
|
||||
|
||||
Args:
|
||||
y_hat (torch.tensor): Predicted waveform.
|
||||
y (torch.tensor): Real waveform.
|
||||
ap (AudioProcessor): Audio processor used to process the waveform.
|
||||
name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict: output figures keyed by the name of the figures.
|
||||
""" """Plot vocoder model results"""
|
||||
if name_prefix is None:
|
||||
name_prefix = ""
|
||||
|
||||
# select an instance from batch
|
||||
y_hat = y_hat[0].squeeze().detach().cpu().numpy()
|
||||
y = y[0].squeeze().detach().cpu().numpy()
|
||||
|
||||
spec_fake = ap.melspectrogram(y_hat).T
|
||||
spec_real = ap.melspectrogram(y).T
|
||||
spec_diff = np.abs(spec_fake - spec_real)
|
||||
|
||||
# plot figure and save it
|
||||
fig_wave = plt.figure()
|
||||
plt.subplot(2, 1, 1)
|
||||
plt.plot(y)
|
||||
plt.title("groundtruth speech")
|
||||
plt.subplot(2, 1, 2)
|
||||
plt.plot(y_hat)
|
||||
plt.title("generated speech")
|
||||
plt.tight_layout()
|
||||
plt.close()
|
||||
|
||||
figures = {
|
||||
name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
|
||||
name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
|
||||
name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
|
||||
name_prefix + "speech_comparison": fig_wave,
|
||||
}
|
||||
return figures
|
||||
Reference in New Issue
Block a user