Add files via upload
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
||||
from TTS.utils.audio.processor import AudioProcessor
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,485 @@
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import scipy
|
||||
import soundfile as sf
|
||||
from librosa import magphase, pyin
|
||||
|
||||
# For using kwargs
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
|
||||
def build_mel_basis(
    *,
    sample_rate: int = None,
    fft_size: int = None,
    num_mels: int = None,
    mel_fmax: int = None,
    mel_fmin: int = None,
    **kwargs,
) -> np.ndarray:
    """Create the mel filterbank matrix via ``librosa.filters.mel``.

    Args:
        sample_rate (int): Sampling rate of the audio.
        fft_size (int): FFT size, forwarded as ``n_fft``.
        num_mels (int): Number of mel bands, forwarded as ``n_mels``.
        mel_fmax (int): Upper band edge. When given it must not exceed
            ``sample_rate // 2`` and must be greater than ``mel_fmin``.
        mel_fmin (int): Lower band edge. Must be a number whenever ``mel_fmax``
            is given, because their difference is checked.

    Returns:
        np.ndarray: Whatever ``librosa.filters.mel`` returns for these arguments.

    Raises:
        AssertionError: If ``mel_fmax`` violates one of the bounds above.
    """
    upper_edge_given = mel_fmax is not None
    if upper_edge_given:
        nyquist = sample_rate // 2
        assert mel_fmax <= nyquist
        assert mel_fmax - mel_fmin > 0
    filter_args = dict(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
    return librosa.filters.mel(**filter_args)
|
||||
|
||||
|
||||
def millisec_to_length(
    *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
) -> Tuple[int, int]:
    """Convert frame length and frame shift from milliseconds to samples.

    Args:
        frame_length_ms (int): STFT window length in milliseconds.
        frame_shift_ms (int): Shift between frames in milliseconds. It must
            divide ``frame_length_ms`` exactly.
        sample_rate (int): Sampling rate of the audio.

    Returns:
        Tuple[int, int]: ``(win_length, hop_length)`` in samples. Note the
        order: the window length comes first, the hop length second.

    Raises:
        AssertionError: If ``frame_length_ms / frame_shift_ms`` is not a whole number.
    """
    factor = frame_length_ms / frame_shift_ms
    assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
    win_length = int(frame_length_ms / 1000.0 * sample_rate)
    # The hop is derived from the window so that win_length / hop_length stays equal to `factor`.
    hop_length = int(win_length / float(factor))
    return win_length, hop_length
|
||||
|
||||
|
||||
def _log(x, base):
|
||||
if base == 10:
|
||||
return np.log10(x)
|
||||
return np.log(x)
|
||||
|
||||
|
||||
def _exp(x, base):
|
||||
if base == 10:
|
||||
return np.power(10, x)
|
||||
return np.exp(x)
|
||||
|
||||
|
||||
def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
    """Map amplitudes to a log scale as ``gain * log_base(max(x, 1e-8))``.

    Args:
        x (np.ndarray): Amplitude spectrogram; every value must be non-negative.
        gain (float): Multiplier applied to the logarithm. Defaults to 1.
        base (int): Logarithm base; 10 selects log10, any other value the
            natural log. Defaults to 10.

    Returns:
        np.ndarray: Log-scaled spectrogram.

    Raises:
        AssertionError: If ``x`` contains a negative value.
    """
    negative_count = (x < 0).sum()
    assert negative_count == 0, " [!] Input values must be non-negative."
    # Floor at 1e-8 so that zeros do not produce -inf.
    floored = np.maximum(1e-8, x)
    return gain * _log(floored, base)
|
||||
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
    """Undo ``amp_to_db``: compute ``base ** (x / gain)``.

    Args:
        x (np.ndarray): Log-scaled spectrogram.
        gain (float): Gain that was applied to the logarithm. Defaults to 1.
        base (int): Logarithm base; 10 selects a power of ten, any other value
            the natural exponential. Defaults to 10.

    Returns:
        np.ndarray: Amplitude spectrogram.
    """
    exponent = x / gain
    return _exp(exponent, base)
|
||||
|
||||
|
||||
def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
    """Filter the signal with the FIR filter ``y[n] = x[n] - coef * x[n - 1]``.

    Args:
        x (np.ndarray): Audio signal.
        coef (float): Pre-emphasis coefficient. Defaults to 0.97.

    Raises:
        RuntimeError: If ``coef`` is 0.

    Returns:
        np.ndarray: Pre-emphasized audio signal.
    """
    if coef == 0:
        raise RuntimeError(" [!] Preemphasis is set 0.0.")
    numerator = [1, -coef]
    denominator = [1]
    return scipy.signal.lfilter(numerator, denominator, x)
|
||||
|
||||
|
||||
def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
    """Invert pre-emphasis with the IIR filter ``y[n] = x[n] + coef * y[n - 1]``.

    Args:
        x (np.ndarray): Pre-emphasized audio signal.
        coef (float): Coefficient used for pre-emphasis. Defaults to 0.97.

    Raises:
        RuntimeError: If ``coef`` is 0.

    Returns:
        np.ndarray: De-emphasized audio signal.
    """
    if coef == 0:
        raise RuntimeError(" [!] Preemphasis is set 0.0.")
    numerator = [1]
    denominator = [1, -coef]
    return scipy.signal.lfilter(numerator, denominator, x)
|
||||
|
||||
|
||||
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
    """Project a linear spectrogram onto the mel filterbank.

    Args:
        spec (np.ndarray): Linear spectrogram.
        mel_basis (np.ndarray): Mel filterbank matrix.

    Shapes:
        - spec: :math:`[C, T]`

    Returns:
        np.ndarray: ``np.dot(mel_basis, spec)``.
    """
    mel = np.dot(mel_basis, spec)
    return mel
|
||||
|
||||
|
||||
def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
    """Approximate a linear spectrogram from a melspectrogram.

    The mel filterbank is inverted with the pseudo-inverse and the result is
    floored at 1e-10.

    Args:
        mel (np.ndarray): Melspectrogram; every value must be non-negative.
        mel_basis (np.ndarray): Mel filterbank matrix.

    Returns:
        np.ndarray: Approximated linear spectrogram.

    Raises:
        AssertionError: If ``mel`` contains a negative value.
    """
    negative_count = (mel < 0).sum()
    assert negative_count == 0, " [!] Input values must be non-negative."
    projected = np.dot(np.linalg.pinv(mel_basis), mel)
    return np.maximum(1e-10, projected)
|
||||
|
||||
|
||||
def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
    """Compute the magnitude spectrogram of a waveform.

    Args:
        wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
        **kwargs: Forwarded to ``stft``.

    Returns:
        np.ndarray: Magnitude of the STFT as float32.
    """
    magnitude = np.abs(stft(y=wav, **kwargs))
    return magnitude.astype(np.float32)
|
||||
|
||||
|
||||
def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
    """Compute a melspectrogram from a waveform.

    Args:
        wav (np.ndarray): Waveform.
        mel_basis: Mel filterbank matrix passed to ``spec_to_mel``.
        **kwargs: Forwarded to both ``stft`` and ``spec_to_mel``.

    Returns:
        np.ndarray: Melspectrogram as float32.
    """
    magnitude = np.abs(stft(y=wav, **kwargs))
    mel = spec_to_mel(spec=magnitude, mel_basis=mel_basis, **kwargs)
    return mel.astype(np.float32)
|
||||
|
||||
|
||||
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
    """Turn a spectrogram into a waveform with the Griffin-Lim vocoder.

    Args:
        spec (np.ndarray): Spectrogram. It is copied, not modified.
        power (float): Exponent applied to the spectrogram before Griffin-Lim.
            Defaults to 1.5.
        **kwargs: Forwarded to ``griffin_lim``.

    Returns:
        np.ndarray: Output of ``griffin_lim``.
    """
    sharpened = spec.copy() ** power
    return griffin_lim(spec=sharpened, **kwargs)
|
||||
|
||||
|
||||
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
    """Turn a melspectrogram into a waveform with the Griffin-Lim vocoder.

    Args:
        mel (np.ndarray): Melspectrogram. It is copied, not modified.
        power (float): Exponent applied to the linear spectrogram before
            Griffin-Lim. Defaults to 1.5.
        **kwargs: Must contain ``mel_basis``; the whole mapping is also
            forwarded to ``griffin_lim``.

    Returns:
        np.ndarray: Output of ``griffin_lim``.

    Raises:
        KeyError: If ``mel_basis`` is missing from ``kwargs``.
    """
    mel_basis = kwargs["mel_basis"]
    # Go back to the linear frequency scale first.
    linear = mel_to_spec(mel=mel.copy(), mel_basis=mel_basis)
    return griffin_lim(spec=linear**power, **kwargs)
|
||||
|
||||
|
||||
### STFT and ISTFT ###
|
||||
def stft(
    *,
    y: np.ndarray = None,
    fft_size: int = None,
    hop_length: int = None,
    win_length: int = None,
    pad_mode: str = "reflect",
    window: str = "hann",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Keyword-only wrapper around ``librosa.stft``.

    ``fft_size`` is passed as ``n_fft``; every other named argument is passed
    under its own name. Extra keyword arguments are accepted and ignored.
    See http://librosa.org/doc/main/generated/librosa.stft.html for details.

    Returns:
        np.ndarray: Complex number array.
    """
    librosa_args = dict(
        y=y,
        n_fft=fft_size,
        hop_length=hop_length,
        win_length=win_length,
        pad_mode=pad_mode,
        window=window,
        center=center,
    )
    return librosa.stft(**librosa_args)
|
||||
|
||||
|
||||
def istft(
    *,
    y: np.ndarray = None,
    hop_length: int = None,
    win_length: int = None,
    window: str = "hann",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Keyword-only wrapper around ``librosa.istft``.

    ``y`` is the STFT matrix to invert. Extra keyword arguments are accepted
    and ignored.
    See http://librosa.org/doc/main/generated/librosa.istft.html for details.

    Returns:
        np.ndarray: Time-domain signal reconstructed by ``librosa.istft``.
    """
    librosa_args = dict(hop_length=hop_length, win_length=win_length, center=center, window=window)
    return librosa.istft(y, **librosa_args)
|
||||
|
||||
|
||||
def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
    """Estimate a waveform from a magnitude spectrogram by iterative phase refinement.

    Starts from a random phase, then alternates ``stft`` and ``istft`` for
    ``num_iter`` rounds, keeping the given magnitude each time.

    Args:
        spec (np.ndarray): Magnitude spectrogram.
        num_iter (int): Number of refinement iterations. Defaults to 60.
        **kwargs: Forwarded to both ``stft`` and ``istft``.

    Returns:
        np.ndarray: Estimated waveform, or ``np.array([0.0])`` when the first
        inversion contains non-finite values.
    """
    random_phase = np.exp(2j * np.pi * np.random.rand(*spec.shape))
    magnitude = np.abs(spec).astype(complex)
    waveform = istft(y=magnitude * random_phase, **kwargs)
    if not np.isfinite(waveform).all():
        print(" [!] Waveform is not finite everywhere. Skipping the GL.")
        return np.array([0.0])
    for _ in range(num_iter):
        # Keep only the phase of the re-analysed signal and pair it with the target magnitude.
        reanalysed = stft(y=waveform, **kwargs)
        phase = np.exp(1j * np.angle(reanalysed))
        waveform = istft(y=magnitude * phase, **kwargs)
    return waveform
|
||||
|
||||
|
||||
def compute_stft_paddings(
    *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
) -> Tuple[int, int]:
    """Compute how many samples to pad so the length reaches the next multiple of ``hop_length``.

    Args:
        x (np.ndarray): Signal; only ``x.shape[0]`` is used.
        hop_length (int): Hop length in samples.
        pad_two_sides (bool): Split the padding between both ends instead of
            putting all of it on the right. Defaults to False.

    Returns:
        Tuple[int, int]: ``(left, right)`` padding. When split, the odd sample
        goes to the right side.
    """
    n_samples = x.shape[0]
    n_frames = n_samples // hop_length + 1
    total = n_frames * hop_length - n_samples
    if pad_two_sides:
        half, remainder = divmod(total, 2)
        return half, half + remainder
    return 0, total
|
||||
|
||||
|
||||
def compute_f0(
    *,
    x: np.ndarray = None,
    pitch_fmax: float = None,
    pitch_fmin: float = None,
    hop_length: int = None,
    win_length: int = None,
    sample_rate: int = None,
    stft_pad_mode: str = "reflect",
    center: bool = True,
    **kwargs,
) -> np.ndarray:
    """Estimate the pitch (f0) track of a waveform with ``librosa.pyin``.

    Args:
        x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
        pitch_fmax (float): Upper pitch bound; required.
        pitch_fmin (float): Lower pitch bound; required.
        hop_length (int): Hop length handed to pYIN.
        win_length (int): Used as pYIN ``frame_length``; half of it is used as
            pYIN ``win_length``.
        sample_rate (int): Audio sampling rate.
        stft_pad_mode (str): Padding mode handed to pYIN. Defaults to "reflect".
        center (bool): Centered framing. Defaults to True.

    Returns:
        np.ndarray: Pitch track with unvoiced frames set to 0.0.

    Raises:
        AssertionError: If ``pitch_fmax`` or ``pitch_fmin`` is None.
    """
    assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
    assert pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."

    pyin_options = dict(
        fmin=pitch_fmin,
        fmax=pitch_fmax,
        sr=sample_rate,
        frame_length=win_length,
        win_length=win_length // 2,
        hop_length=hop_length,
        pad_mode=stft_pad_mode,
        center=center,
        n_thresholds=100,
        beta_parameters=(2, 18),
        boltzmann_parameter=2,
        resolution=0.1,
        max_transition_rate=35.92,
        switch_prob=0.01,
        no_trough_prob=0.01,
    )
    f0, voiced_mask, _ = pyin(y=x.astype(np.double), **pyin_options)
    # Overwrite unvoiced frames with zero.
    f0[~voiced_mask] = 0.0
    return f0
|
||||
|
||||
|
||||
def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
    """Compute the per-frame energy of a waveform.

    The energy of a frame is the L2 norm of its STFT magnitudes.

    Args:
        y (np.ndarray): Waveform. Shape :math:`[T_wav,]`
        **kwargs: Forwarded to ``stft``.

    Returns:
        np.ndarray: One energy value per STFT frame.
    """
    magnitude, _ = magphase(stft(y=y, **kwargs))
    squared_sum = np.sum(magnitude**2, axis=0)
    return np.sqrt(squared_sum)
|
||||
|
||||
|
||||
### Audio Processing ###
|
||||
def find_endpoint(
    *,
    wav: np.ndarray = None,
    trim_db: float = -40,
    sample_rate: int = None,
    min_silence_sec=0.8,
    gain: float = None,
    base: int = None,
    **kwargs,
) -> int:
    """Find the end of the first sufficiently long silent stretch in a signal.

    Args:
        wav (np.ndarray): Audio signal.
        trim_db (float, optional): Silence level. Its negation is converted to
            an amplitude threshold with ``db_to_amp``. Defaults to -40.
        sample_rate (int): Sampling rate, used to turn ``min_silence_sec``
            into a window length in samples.
        min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8.
        gain (float, optional): Gain used to convert trim_db to an amplitude.
            ``None`` is treated as 1.
        base (int, optional): Base of the logarithm used to convert trim_db to
            an amplitude. ``None`` is treated as 10.

    Returns:
        int: Index just past the first hop of the first silent window, or
        ``len(wav)`` when no window is silent.
    """
    # The None defaults are unusable downstream: gain=None fails on `x / gain`,
    # base=None would silently select the natural exponential. Use the
    # documented defaults instead.
    if gain is None:
        gain = 1
    if base is None:
        base = 10
    window_length = int(sample_rate * min_silence_sec)
    hop_length = int(window_length / 4)
    threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
    # NOTE(review): the window maximum is taken over signed samples, not magnitudes.
    for start in range(hop_length, len(wav) - window_length, hop_length):
        if np.max(wav[start : start + window_length]) < threshold:
            return start + hop_length
    return len(wav)
|
||||
|
||||
|
||||
def trim_silence(
    *,
    wav: np.ndarray = None,
    sample_rate: int = None,
    trim_db: float = None,
    win_length: int = None,
    hop_length: int = None,
    **kwargs,
) -> np.ndarray:
    """Drop a 0.01 sec margin from both ends, then trim silence with ``librosa.effects.trim``.

    Args:
        wav (np.ndarray): Audio signal.
        sample_rate (int): Sampling rate, used to size the margin.
        trim_db (float): Passed to librosa as ``top_db``.
        win_length (int): Passed to librosa as ``frame_length``.
        hop_length (int): Passed to librosa as ``hop_length``.

    Returns:
        np.ndarray: Trimmed signal (first element of librosa's return value).
    """
    margin = int(sample_rate * 0.01)
    # A zero margin must skip the slice: `wav[0:-0]` is `wav[0:0]`, i.e. empty.
    if margin > 0:
        wav = wav[margin:-margin]
    return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
|
||||
|
||||
|
||||
def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
    """Rescale a signal so that its peak magnitude equals ``coef``.

    Args:
        x (np.ndarray): Raw waveform.
        coef (float): Peak magnitude after rescaling. Defaults to 0.95.

    Returns:
        np.ndarray: Volume normalized waveform. An all-zero input is returned
        as zeros instead of NaNs.
    """
    peak = np.abs(x).max()
    if peak == 0:
        # Silence has no peak to scale to; dividing would give NaN everywhere.
        return x * coef
    return x / peak * coef
|
||||
|
||||
|
||||
def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
    """Scale a signal so that its RMS equals ``10 ** (db_level / 20)``.

    Args:
        wav (np.ndarray): Raw waveform.
        db_level (float): Target RMS level in dB. Defaults to -27.0.

    Returns:
        np.ndarray: Rescaled waveform. An all-zero input is returned unscaled
        instead of as NaNs.
    """
    target_rms = 10 ** (db_level / 20)
    energy = np.sum(wav**2)
    if energy == 0:
        # Silence cannot be scaled to a target level; avoid dividing by zero.
        return wav * 1.0
    scale = np.sqrt((len(wav) * (target_rms**2)) / energy)
    return wav * scale
|
||||
|
||||
|
||||
def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
    """Validate ``db_level`` and RMS-normalize the signal with ``rms_norm``.

    Args:
        x (np.ndarray): Raw waveform.
        db_level (float): Target dB level in RMS, within [-99, 0]. Defaults to -27.0.

    Returns:
        np.ndarray: RMS normalized waveform.

    Raises:
        AssertionError: If ``db_level`` lies outside [-99, 0].
    """
    level_in_range = -99 <= db_level <= 0
    assert level_in_range, " [!] db_level should be between -99 and 0"
    return rms_norm(wav=x, db_level=db_level)
|
||||
|
||||
|
||||
def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
    """Read an audio file, optionally resampling it while loading.

    Resampling slows down loading significantly, so it is recommended to
    resample the files beforehand.

    Args:
        filename (str): Path to the wav file.
        sample_rate (int, optional): Target rate handed to ``librosa.load``.
            Only used when ``resample`` is True. Defaults to None.
        resample (bool, optional): Load through ``librosa.load`` with
            ``sr=sample_rate`` instead of ``soundfile.read``. Defaults to False.

    Returns:
        np.ndarray: Loaded waveform. The rate reported by the reader is discarded.
    """
    if not resample:
        # soundfile is the faster reader when no resampling is needed.
        samples, _ = sf.read(filename)
        return samples
    # Slow path: librosa resamples while loading.
    samples, _ = librosa.load(filename, sr=sample_rate)
    return samples
|
||||
|
||||
|
||||
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None:
    """Peak-normalize a float waveform to 16-bit PCM and write it with Scipy.

    Args:
        wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
        path (str): Path to the output file. It is always written, also when
            ``pipe_out`` is given.
        sample_rate (int, optional): Sampling rate stored in the file. Defaults to None.
        pipe_out (optional): Object with a ``buffer`` attribute (such as
            ``sys.stdout``) that additionally receives the encoded wav bytes.
    """
    # Peaks below 0.01 are treated as 0.01 so near-silent audio is not blown up.
    peak = np.max(np.abs(wav))
    scale = 32767 / max(0.01, peak)
    pcm = (wav * scale).astype(np.int16)
    if pipe_out:
        encoded = BytesIO()
        scipy.io.wavfile.write(encoded, sample_rate, pcm)
        pipe_out.buffer.write(encoded.getvalue())
    scipy.io.wavfile.write(path, sample_rate, pcm)
|
||||
|
||||
|
||||
def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
    """Compand a waveform with the mu-law curve and map it to quantization levels.

    Args:
        wav (np.ndarray): Waveform.
        mulaw_qc (int): Number of quantization bits; ``mu = 2**mulaw_qc - 1``.

    Returns:
        np.ndarray: Floored level indices, returned as floats.
    """
    mu = 2**mulaw_qc - 1
    signed_log = np.sign(wav) * np.log(1 + mu * np.abs(wav))
    companded = signed_log / np.log(1.0 + mu)
    # Shift from [-1, 1] to [0, mu]; the extra 0.5 makes the floor round to nearest.
    levels = (companded + 1) / 2 * mu + 0.5
    return np.floor(levels)
|
||||
|
||||
|
||||
def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
    """Apply the mu-law expansion ``sign(wav) / mu * ((1 + mu) ** |wav| - 1)``.

    Args:
        wav: Companded signal.
        mulaw_qc (int): Number of quantization bits; ``mu = 2**mulaw_qc - 1``.

    Returns:
        np.ndarray: Expanded waveform.
    """
    mu = 2**mulaw_qc - 1
    signed_scale = np.sign(wav) / mu
    expanded = (1 + mu) ** np.abs(wav) - 1
    return signed_scale * expanded
|
||||
|
||||
|
||||
def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
    """Scale a float waveform by 2**15, clip to the int16 range, and cast to int16."""
    full_scale = 2**15
    scaled = x * full_scale
    return np.clip(scaled, -full_scale, full_scale - 1).astype(np.int16)
|
||||
|
||||
|
||||
def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
    """Map a waveform from ``[-1, 1]`` onto ``[0, 2**quantize_bits - 1]``.

    The result is only rescaled; it is not rounded or cast to an integer type.

    Args:
        x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
        quantize_bits (int): Number of quantization bits.

    Returns:
        np.ndarray: Rescaled waveform.
    """
    top_level = 2**quantize_bits - 1
    shifted = x + 1.0
    return shifted * top_level / 2
|
||||
|
||||
|
||||
def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
    """Map values from ``[0, 2**quantize_bits - 1]`` back onto ``[-1, 1]``."""
    top_level = 2**quantize_bits - 1
    doubled = 2 * x
    return doubled / top_level - 1
|
||||
@@ -0,0 +1,633 @@
|
||||
from io import BytesIO
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
import scipy.signal
|
||||
|
||||
from TTS.tts.utils.helpers import StandardScaler
|
||||
from TTS.utils.audio.numpy_transforms import (
|
||||
amp_to_db,
|
||||
build_mel_basis,
|
||||
compute_f0,
|
||||
db_to_amp,
|
||||
deemphasis,
|
||||
find_endpoint,
|
||||
griffin_lim,
|
||||
load_wav,
|
||||
mel_to_spec,
|
||||
millisec_to_length,
|
||||
preemphasis,
|
||||
rms_volume_norm,
|
||||
spec_to_mel,
|
||||
stft,
|
||||
trim_silence,
|
||||
volume_norm,
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
|
||||
|
||||
class AudioProcessor(object):
|
||||
"""Audio Processor for TTS.
|
||||
|
||||
Note:
|
||||
All the class arguments are set to default values to enable a flexible initialization
|
||||
of the class with the model config. They are not meaningful for all the arguments.
|
||||
|
||||
Args:
|
||||
sample_rate (int, optional):
|
||||
target audio sampling rate. Defaults to None.
|
||||
|
||||
resample (bool, optional):
|
||||
enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
|
||||
|
||||
num_mels (int, optional):
|
||||
number of melspectrogram dimensions. Defaults to None.
|
||||
|
||||
log_func (int, optional):
|
||||
log exponent used for converting spectrogram aplitude to DB.
|
||||
|
||||
min_level_db (int, optional):
|
||||
minimum db threshold for the computed melspectrograms. Defaults to None.
|
||||
|
||||
frame_shift_ms (int, optional):
|
||||
milliseconds of frames between STFT columns. Defaults to None.
|
||||
|
||||
frame_length_ms (int, optional):
|
||||
milliseconds of STFT window length. Defaults to None.
|
||||
|
||||
hop_length (int, optional):
|
||||
number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
|
||||
|
||||
win_length (int, optional):
|
||||
STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
|
||||
|
||||
ref_level_db (int, optional):
|
||||
reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
|
||||
|
||||
fft_size (int, optional):
|
||||
FFT window size for STFT. Defaults to 1024.
|
||||
|
||||
power (int, optional):
|
||||
Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
|
||||
|
||||
preemphasis (float, optional):
|
||||
Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
|
||||
|
||||
signal_norm (bool, optional):
|
||||
enable/disable signal normalization. Defaults to None.
|
||||
|
||||
symmetric_norm (bool, optional):
|
||||
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
|
||||
|
||||
max_norm (float, optional):
|
||||
```k``` defining the normalization range. Defaults to None.
|
||||
|
||||
mel_fmin (int, optional):
|
||||
minimum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
mel_fmax (int, optional):
|
||||
maximum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
pitch_fmin (int, optional):
|
||||
minimum filter frequency for computing pitch. Defaults to None.
|
||||
|
||||
pitch_fmax (int, optional):
|
||||
maximum filter frequency for computing pitch. Defaults to None.
|
||||
|
||||
spec_gain (int, optional):
|
||||
gain applied when converting amplitude to DB. Defaults to 20.
|
||||
|
||||
stft_pad_mode (str, optional):
|
||||
Padding mode for STFT. Defaults to 'reflect'.
|
||||
|
||||
clip_norm (bool, optional):
|
||||
enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
|
||||
|
||||
griffin_lim_iters (int, optional):
|
||||
Number of GriffinLim iterations. Defaults to None.
|
||||
|
||||
do_trim_silence (bool, optional):
|
||||
enable/disable silence trimming when loading the audio signal. Defaults to False.
|
||||
|
||||
trim_db (int, optional):
|
||||
DB threshold used for silence trimming. Defaults to 60.
|
||||
|
||||
do_sound_norm (bool, optional):
|
||||
enable/disable signal normalization. Defaults to False.
|
||||
|
||||
do_amp_to_db_linear (bool, optional):
|
||||
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
|
||||
|
||||
do_amp_to_db_mel (bool, optional):
|
||||
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
||||
|
||||
do_rms_norm (bool, optional):
|
||||
enable/disable RMS volume normalization when loading an audio file. Defaults to False.
|
||||
|
||||
db_level (int, optional):
|
||||
dB level used for rms normalization. The range is -99 to 0. Defaults to None.
|
||||
|
||||
stats_path (str, optional):
|
||||
Path to the computed stats file. Defaults to None.
|
||||
|
||||
verbose (bool, optional):
|
||||
enable/disable logging. Defaults to True.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    sample_rate=None,
    resample=False,
    num_mels=None,
    log_func="np.log10",
    min_level_db=None,
    frame_shift_ms=None,
    frame_length_ms=None,
    hop_length=None,
    win_length=None,
    ref_level_db=None,
    fft_size=1024,
    power=None,
    preemphasis=0.0,
    signal_norm=None,
    symmetric_norm=None,
    max_norm=None,
    mel_fmin=None,
    mel_fmax=None,
    pitch_fmax=None,
    pitch_fmin=None,
    spec_gain=20,
    stft_pad_mode="reflect",
    clip_norm=True,
    griffin_lim_iters=None,
    do_trim_silence=False,
    trim_db=60,
    do_sound_norm=False,
    do_amp_to_db_linear=True,
    do_amp_to_db_mel=True,
    do_rms_norm=False,
    db_level=None,
    stats_path=None,
    verbose=True,
    **_,
):
    """Store the audio settings, derive the STFT sizes, and build the mel filterbank.

    Unknown keyword arguments are accepted and ignored.

    Raises:
        ValueError: If ``log_func`` is neither "np.log" nor "np.log10".
        AssertionError: If ``min_level_db`` equals 0 or ``win_length`` exceeds ``fft_size``.
    """
    # Plain settings. The assignment order is kept stable because the verbose
    # dump below prints vars(self) in insertion order.
    self.sample_rate = sample_rate
    self.resample = resample
    self.num_mels = num_mels
    self.log_func = log_func
    self.min_level_db = min_level_db or 0
    self.frame_shift_ms = frame_shift_ms
    self.frame_length_ms = frame_length_ms
    self.ref_level_db = ref_level_db
    self.fft_size = fft_size
    self.power = power
    self.preemphasis = preemphasis
    self.griffin_lim_iters = griffin_lim_iters
    self.signal_norm = signal_norm
    self.symmetric_norm = symmetric_norm
    self.mel_fmin = mel_fmin or 0
    self.mel_fmax = mel_fmax
    self.pitch_fmin = pitch_fmin
    self.pitch_fmax = pitch_fmax
    self.spec_gain = float(spec_gain)
    self.stft_pad_mode = stft_pad_mode
    self.max_norm = float(max_norm) if max_norm is not None else 1.0
    self.clip_norm = clip_norm
    self.do_trim_silence = do_trim_silence
    self.trim_db = trim_db
    self.do_sound_norm = do_sound_norm
    self.do_amp_to_db_linear = do_amp_to_db_linear
    self.do_amp_to_db_mel = do_amp_to_db_mel
    self.do_rms_norm = do_rms_norm
    self.db_level = db_level
    self.stats_path = stats_path

    # Logarithm base used by the amplitude <-> dB conversions.
    if log_func == "np.log10":
        self.base = 10
    elif log_func == "np.log":
        self.base = np.e
    else:
        raise ValueError(" [!] unknown `log_func` value.")

    # STFT sizes: explicit sample counts win; otherwise derive them from milliseconds.
    if hop_length is not None:
        self.hop_length = hop_length
        self.win_length = win_length
    else:
        self.win_length, self.hop_length = millisec_to_length(
            frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
        )

    assert min_level_db != 0.0, " [!] min_level_db is 0"
    assert (
        self.win_length <= self.fft_size
    ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"

    members = vars(self)
    if verbose:
        print(" > Setting up Audio Processor...")
        for key, value in members.items():
            print(f" | > {key}:{value}")

    # Mel filterbank used by the melspectrogram conversions.
    self.mel_basis = build_mel_basis(
        sample_rate=self.sample_rate,
        fft_size=self.fft_size,
        num_mels=self.num_mels,
        mel_fmax=self.mel_fmax,
        mel_fmin=self.mel_fmin,
    )

    # With a stats file, mean-std scaling replaces range normalization, so the
    # range-normalization settings are cleared.
    if stats_path and signal_norm:
        mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
        self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
        self.signal_norm = True
        self.max_norm = None
        self.clip_norm = None
        self.symmetric_norm = None
|
||||
|
||||
@staticmethod
def init_from_config(config: "Coqpit", verbose=True):
    """Build an ``AudioProcessor`` from a config.

    When the config contains an ``audio`` entry, that entry supplies the
    keyword arguments; otherwise the config itself does.
    """
    audio_params = config.audio if "audio" in config else config
    return AudioProcessor(verbose=verbose, **audio_params)
|
||||
|
||||
### normalization ###
|
||||
def normalize(self, S: np.ndarray) -> np.ndarray:
    """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`

    When scalers are set up (``self.mel_scaler`` exists), mean-std scaling is
    applied instead, with the scaler picked from the number of rows of ``S``.
    When ``self.signal_norm`` is falsy, a copy of ``S`` is returned unchanged.

    Args:
        S (np.ndarray): Spectrogram to normalize. It is copied, not modified.

    Raises:
        RuntimeError: The row count matches neither the mel nor the linear spectrogram size.

    Returns:
        np.ndarray: Normalized spectrogram.
    """
    S = S.copy()
    if not self.signal_norm:
        return S

    # mean-var scaling
    if hasattr(self, "mel_scaler"):
        n_bins = S.shape[0]
        if n_bins == self.num_mels:
            return self.mel_scaler.transform(S.T).T
        # An STFT of size fft_size has fft_size // 2 + 1 frequency bins. The
        # former `fft_size / 2` comparison is still accepted for compatibility.
        if n_bins in (self.fft_size // 2 + 1, self.fft_size / 2):
            return self.linear_scaler.transform(S.T).T
        raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")

    # range normalization
    S -= self.ref_level_db  # discard certain range of DB assuming it is air noise
    S_norm = (S - self.min_level_db) / (-self.min_level_db)
    if self.symmetric_norm:
        S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
        if self.clip_norm:
            S_norm = np.clip(
                S_norm, -self.max_norm, self.max_norm  # pylint: disable=invalid-unary-operand-type
            )
        return S_norm
    S_norm = self.max_norm * S_norm
    if self.clip_norm:
        S_norm = np.clip(S_norm, 0, self.max_norm)
    return S_norm
|
||||
|
||||
def denormalize(self, S: np.ndarray) -> np.ndarray:
    """Denormalize spectrogram values.

    Reverses ``normalize``: the scalers are inverted when they are set up
    (``self.mel_scaler`` exists), otherwise the range normalization is undone.
    When ``self.signal_norm`` is falsy, a copy of ``S`` is returned unchanged.

    Args:
        S (np.ndarray): Spectrogram to denormalize. It is copied, not modified.

    Raises:
        RuntimeError: The row count matches neither the mel nor the linear spectrogram size.

    Returns:
        np.ndarray: Denormalized spectrogram.
    """
    S_denorm = S.copy()
    if not self.signal_norm:
        return S_denorm

    # mean-var scaling
    if hasattr(self, "mel_scaler"):
        n_bins = S_denorm.shape[0]
        if n_bins == self.num_mels:
            return self.mel_scaler.inverse_transform(S_denorm.T).T
        # An STFT of size fft_size has fft_size // 2 + 1 frequency bins. The
        # former `fft_size / 2` comparison is still accepted for compatibility.
        if n_bins in (self.fft_size // 2 + 1, self.fft_size / 2):
            return self.linear_scaler.inverse_transform(S_denorm.T).T
        raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")

    if self.symmetric_norm:
        if self.clip_norm:
            S_denorm = np.clip(
                S_denorm, -self.max_norm, self.max_norm  # pylint: disable=invalid-unary-operand-type
            )
        S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
        return S_denorm + self.ref_level_db
    if self.clip_norm:
        S_denorm = np.clip(S_denorm, 0, self.max_norm)
    S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
    return S_denorm + self.ref_level_db
|
||||
|
||||
### Mean-STD scaling ###
|
||||
def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]:
    """Load mean and std statistics from a `npy` file and check them against this processor.

    Args:
        stats_path (str): Path to the `npy` file. It must hold a dict with the keys
            ``mel_mean``, ``mel_std``, ``linear_mean``, ``linear_std`` and ``audio_config``.

    Returns:
        Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to
            compute them.

    Raises:
        AssertionError: A checked audio parameter differs from the value stored in the stats file.
    """
    # NOTE(review): allow_pickle=True unpickles the file contents; only load stats files you trust.
    stats = np.load(stats_path, allow_pickle=True).item()  # pylint: disable=unexpected-keyword-arg
    stats_config = stats["audio_config"]

    # Parameters that may differ between stats computation and this processor.
    skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"]
    unchecked_parameters = ["sample_rate", "trim_db"]
    for key in stats_config:
        if key in skip_parameters or key in unchecked_parameters:
            continue
        assert (
            stats_config[key] == self.__dict__[key]
        ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"

    return stats["mel_mean"], stats["mel_std"], stats["linear_mean"], stats["linear_std"], stats_config
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
def setup_scaler(
    self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray
) -> None:
    """Create ``self.mel_scaler`` and ``self.linear_scaler`` for mean-std normalization.

    Args:
        mel_mean (np.ndarray): Mean for melspectrograms.
        mel_std (np.ndarray): STD for melspectrograms.
        linear_mean (np.ndarray): Mean for full scale spectrograms.
        linear_std (np.ndarray): STD for full scale spectrograms.
    """
    mel_scaler = StandardScaler()
    mel_scaler.set_stats(mel_mean, mel_std)
    self.mel_scaler = mel_scaler

    linear_scaler = StandardScaler()
    linear_scaler.set_stats(linear_mean, linear_std)
    self.linear_scaler = linear_scaler
|
||||
|
||||
### Preemphasis ###
|
||||
def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
    """Run ``preemphasis`` on the signal with this processor's coefficient.

    Args:
        x (np.ndarray): Audio signal.

    Raises:
        RuntimeError: Preemphasis coeff is set to 0.

    Returns:
        np.ndarray: Pre-emphasized audio signal.
    """
    coefficient = self.preemphasis
    return preemphasis(x=x, coef=coefficient)
|
||||
|
||||
def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
    """Undo pre-emphasis on ``x`` using the coefficient stored in ``self.preemphasis``."""
    coefficient = self.preemphasis
    return deemphasis(x=x, coef=coefficient)
|
||||
|
||||
### SPECTROGRAMs ###
|
||||
def spectrogram(self, y: np.ndarray) -> np.ndarray:
    """Turn a waveform into a normalized linear spectrogram.

    Steps: optional pre-emphasis, STFT magnitude, optional amplitude-to-dB
    conversion (``self.do_amp_to_db_linear``), then ``self.normalize``.

    Args:
        y (np.ndarray): Waveform.

    Returns:
        np.ndarray: Spectrogram, cast to ``float32``.
    """
    wav = y if self.preemphasis == 0 else self.apply_preemphasis(y)
    magnitude = np.abs(
        stft(
            y=wav,
            fft_size=self.fft_size,
            hop_length=self.hop_length,
            win_length=self.win_length,
            pad_mode=self.stft_pad_mode,
        )
    )
    if self.do_amp_to_db_linear:
        magnitude = amp_to_db(x=magnitude, gain=self.spec_gain, base=self.base)
    return self.normalize(magnitude).astype(np.float32)
|
||||
|
||||
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
    """Turn a waveform into a normalized melspectrogram.

    Steps: optional pre-emphasis, STFT magnitude, projection onto
    ``self.mel_basis``, optional amplitude-to-dB conversion
    (``self.do_amp_to_db_mel``), then ``self.normalize``.

    Args:
        y (np.ndarray): Waveform.

    Returns:
        np.ndarray: Melspectrogram, cast to ``float32``.
    """
    wav = y if self.preemphasis == 0 else self.apply_preemphasis(y)
    complex_spec = stft(
        y=wav,
        fft_size=self.fft_size,
        hop_length=self.hop_length,
        win_length=self.win_length,
        pad_mode=self.stft_pad_mode,
    )
    mel = spec_to_mel(spec=np.abs(complex_spec), mel_basis=self.mel_basis)
    if self.do_amp_to_db_mel:
        mel = amp_to_db(x=mel, gain=self.spec_gain, base=self.base)
    return self.normalize(mel).astype(np.float32)
|
||||
|
||||
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
    """Convert a spectrogram to a waveform using the Griffin-Lim vocoder.

    The input is denormalized, converted from dB to amplitude, raised to
    ``self.power`` and passed to Griffin-Lim. De-emphasis is applied at the
    end when ``self.preemphasis`` is non-zero.

    Args:
        spectrogram (np.ndarray): Normalized linear spectrogram.

    Returns:
        np.ndarray: Reconstructed waveform.
    """
    # NOTE: the dB-to-amplitude step is unconditional here, regardless of do_amp_to_db_linear.
    amplitude = db_to_amp(x=self.denormalize(spectrogram), gain=self.spec_gain, base=self.base)
    # Griffin-Lim estimates the phase that the magnitude spectrogram lacks.
    wav = self._griffin_lim(amplitude**self.power)
    if self.preemphasis != 0:
        return self.apply_inv_preemphasis(wav)
    return wav
|
||||
|
||||
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
    """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder.

    The input is denormalized, converted from dB to amplitude, mapped back to
    the linear frequency scale with ``self.mel_basis``, raised to
    ``self.power`` and passed to Griffin-Lim. De-emphasis is applied at the
    end when ``self.preemphasis`` is non-zero.

    Args:
        mel_spectrogram (np.ndarray): Normalized melspectrogram.

    Returns:
        np.ndarray: Reconstructed waveform.
    """
    mel_amplitude = db_to_amp(x=self.denormalize(mel_spectrogram), gain=self.spec_gain, base=self.base)
    # Convert back to the linear frequency scale.
    linear = mel_to_spec(mel=mel_amplitude, mel_basis=self.mel_basis)
    wav = self._griffin_lim(linear**self.power)
    if self.preemphasis != 0:
        return self.apply_inv_preemphasis(wav)
    return wav
|
||||
|
||||
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
    """Map a network's normalized linear spectrogram output to a normalized melspectrogram.

    The spectrogram is denormalized, converted from dB to amplitude, projected
    onto ``self.mel_basis``, converted back to dB and normalized again.

    Args:
        linear_spec (np.ndarray): Normalized full scale linear spectrogram.

    Returns:
        np.ndarray: Normalized melspectrogram.
    """
    amplitude = db_to_amp(x=self.denormalize(linear_spec), gain=self.spec_gain, base=self.base)
    mel_amplitude = spec_to_mel(spec=np.abs(amplitude), mel_basis=self.mel_basis)
    mel_db = amp_to_db(x=mel_amplitude, gain=self.spec_gain, base=self.base)
    return self.normalize(mel_db)
|
||||
|
||||
def _griffin_lim(self, S):
    """Run the module-level ``griffin_lim`` on ``S`` with this processor's STFT settings."""
    stft_settings = dict(
        num_iter=self.griffin_lim_iters,
        hop_length=self.hop_length,
        win_length=self.win_length,
        fft_size=self.fft_size,
        pad_mode=self.stft_pad_mode,
    )
    return griffin_lim(spec=S, **stft_settings)
|
||||
|
||||
def compute_f0(self, x: np.ndarray) -> np.ndarray:
    """Estimate the pitch (f0) of a waveform with the same framing parameters as the melspectrogram.

    Args:
        x (np.ndarray): Waveform.

    Returns:
        np.ndarray: Pitch.

    Examples:
        >>> WAV_FILE = filename = librosa.example('vibeace')
        >>> from TTS.config import BaseAudioConfig
        >>> from TTS.utils.audio import AudioProcessor
        >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
        >>> ap = AudioProcessor(**conf)
        >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
        >>> pitch = ap.compute_f0(wav)
    """
    hop = self.hop_length
    # Align the F0 length to the spectrogram length: when the signal is an
    # exact multiple of the hop size, extend it by half a hop at the end.
    if len(x) % hop == 0:
        x = np.pad(x, (0, hop // 2), mode=self.stft_pad_mode)

    # The bare name resolves to the module-level helper, not to this method.
    return compute_f0(
        x=x,
        pitch_fmax=self.pitch_fmax,
        pitch_fmin=self.pitch_fmin,
        hop_length=self.hop_length,
        win_length=self.win_length,
        sample_rate=self.sample_rate,
        stft_pad_mode=self.stft_pad_mode,
        center=True,
    )
|
||||
|
||||
### Audio Processing ###
|
||||
def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int:
    """Locate the last non-silent point at the end of an audio signal.

    The silence threshold is not an argument of this method; it is taken from
    ``self.trim_db``.

    Args:
        wav (np.ndarray): Audio signal.
        min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8.

    Returns:
        int: Last point without silence.
    """
    detector_settings = dict(
        trim_db=self.trim_db,
        sample_rate=self.sample_rate,
        min_silence_sec=min_silence_sec,
        gain=self.spec_gain,
        base=self.base,
    )
    # The bare name resolves to the module-level helper, not to this method.
    return find_endpoint(wav=wav, **detector_settings)
|
||||
|
||||
def trim_silence(self, wav):
    """Trim silent parts with a threshold and 0.01 sec margin.

    Delegates to the module-level ``trim_silence`` using this processor's
    ``sample_rate``, ``trim_db``, ``win_length`` and ``hop_length``.
    """
    trim_settings = dict(
        sample_rate=self.sample_rate,
        trim_db=self.trim_db,
        win_length=self.win_length,
        hop_length=self.hop_length,
    )
    return trim_silence(wav=wav, **trim_settings)
|
||||
|
||||
@staticmethod
def sound_norm(x: np.ndarray) -> np.ndarray:
    """Normalize the volume of an audio signal via the module-level ``volume_norm``.

    Args:
        x (np.ndarray): Raw waveform.

    Returns:
        np.ndarray: Volume normalized waveform.
    """
    normalized = volume_norm(x=x)
    return normalized
|
||||
|
||||
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
    """Normalize the volume based on the RMS of the signal.

    Args:
        x (np.ndarray): Raw waveform.
        db_level (float, optional): Target level. Falls back to ``self.db_level`` when None.

    Returns:
        np.ndarray: RMS normalized waveform.
    """
    target_level = self.db_level if db_level is None else db_level
    # The bare name resolves to the module-level helper, not to this method.
    return rms_volume_norm(x=x, db_level=target_level)
|
||||
|
||||
### save and load ###
|
||||
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
    """Read a wav file and optionally resample, silence trim and volume normalize it.

    Resampling slows down loading the file significantly. Therefore it is
    recommended to resample the file before.

    Args:
        filename (str): Path to the wav file.
        sr (int, optional): Sampling rate for resampling. When given, resampling is
            always enabled; otherwise ``self.sample_rate`` and ``self.resample`` are used.
            Defaults to None.

    Returns:
        np.ndarray: Loaded waveform.
    """
    # The bare name resolves to the module-level loader, not to this method.
    if sr is None:
        x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
    else:
        x = load_wav(filename=filename, sample_rate=sr, resample=True)

    if self.do_trim_silence:
        try:
            x = self.trim_silence(x)
        except ValueError:
            # Best effort: keep the untrimmed signal and report the file.
            print(f" [!] File cannot be trimmed for silence - {filename}")

    if self.do_sound_norm:
        x = self.sound_norm(x)

    if self.do_rms_norm:
        x = self.rms_volume_norm(x, self.db_level)

    return x
|
||||
|
||||
def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
    """Save a waveform as a 16-bit PCM wav file using Scipy.

    The signal is scaled to the int16 range, either through RMS normalization
    (``self.do_rms_norm``) or by peak normalization. The file at ``path`` is
    written in every case; ``pipe_out`` receives an additional copy.

    Args:
        wav (np.ndarray): Waveform to save.
        path (str): Path to a output file.
        sr (int, optional): Sampling rate used for saving to the file. Falls back to
            ``self.sample_rate`` when falsy. Defaults to None.
        pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
            The wav bytes are written to ``pipe_out.buffer``.
    """
    rate = sr if sr else self.sample_rate

    if self.do_rms_norm:
        scaled = self.rms_volume_norm(wav, self.db_level) * 32767
    else:
        # Peak normalization; the 0.01 floor avoids blowing up near-silent audio.
        scaled = wav * (32767 / max(0.01, np.max(np.abs(wav))))

    # RMS normalization can push samples beyond the int16 range. Casting such
    # values directly wraps around and produces loud artifacts, so saturate first.
    pcm = np.clip(scaled, -32768, 32767).astype(np.int16)

    if pipe_out:
        wav_buffer = BytesIO()
        scipy.io.wavfile.write(wav_buffer, rate, pcm)
        pipe_out.buffer.write(wav_buffer.getvalue())

    scipy.io.wavfile.write(path, rate, pcm)
|
||||
|
||||
def get_duration(self, filename: str) -> float:
    """Return the duration of a wav file as reported by ``librosa.get_duration``.

    Args:
        filename (str): Path to the wav file.

    Returns:
        float: Value returned by Librosa for the file.
    """
    duration = librosa.get_duration(filename=filename)
    return duration
|
||||
@@ -0,0 +1,165 @@
|
||||
import librosa
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    Args:

        n_fft (int):
            FFT window size for STFT.

        hop_length (int):
            number of samples between successive STFT columns.

        win_length (int):
            STFT window length.

        pad_wav (bool, optional):
            If True, reflect-pad the audio with (n_fft - hop_length) / 2 samples on both sides
            before the STFT. Defaults to False.

        window (str, optional):
            The name of a ``torch`` function to create a window tensor that is applied/multiplied
            to each frame/window. Defaults to "hann_window".

        sample_rate (int, optional):
            target audio sampling rate. Passed to the mel filter bank when ``use_mel`` is True.
            Defaults to None.

        mel_fmin (int, optional):
            minimum filter frequency for computing melspectrograms. Defaults to 0.

        mel_fmax (int, optional):
            maximum filter frequency for computing melspectrograms. Defaults to None.

        n_mels (int, optional):
            number of melspectrogram dimensions. Defaults to 80.

        use_mel (bool, optional):
            If True compute the melspectrograms, otherwise the linear spectrograms. Defaults to False.

        do_amp_to_db (bool, optional):
            enable/disable amplitude to dB (natural log) conversion of the output. Defaults to False.

        spec_gain (float, optional):
            gain applied when converting amplitude to DB. Defaults to 1.0.

        power (float, optional):
            Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc.
            When None the magnitude is left unchanged. Defaults to None.

        use_htk (bool, optional):
            Use HTK formula in mel filter instead of Slaney. Defaults to False.

        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).

            If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).

            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".

        normalized (bool, optional):
            Forwarded to ``torch.stft`` as its ``normalized`` argument. Defaults to False.
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
        normalized=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        # Registered as a frozen parameter so the window moves with the module across devices.
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        self.normalized = normalized
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: [B x T] or [:math:`[B, 1, T]`]
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # Complex output of shape B x D x T. The real-valued (``return_complex=False``)
        # output of ``torch.stft`` is deprecated, so take the parts of the complex tensor.
        spec = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=self.normalized,
            onesided=True,
            return_complex=True,
        )
        # Clamp before sqrt so the magnitude has a floor of 1e-4.
        S = torch.sqrt(torch.clamp(spec.real**2 + spec.imag**2, min=1e-8))

        if self.power is not None:
            S = S**self.power

        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        """Build the mel filter bank with Librosa and store it as a float tensor in ``self.mel_basis``."""
        mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    @staticmethod
    def _amp_to_db(x, spec_gain=1.0):
        """Natural-log compression of ``x`` (floored at 1e-5) scaled by ``spec_gain``."""
        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

    @staticmethod
    def _db_to_amp(x, spec_gain=1.0):
        """Inverse of ``_amp_to_db`` for values above the clamp floor."""
        return torch.exp(x) / spec_gain
|
||||
@@ -0,0 +1,105 @@
|
||||
class TrainerCallback:
    """Forward trainer lifecycle events to the model, criterion and optimizer.

    Every public hook looks for a method of the same name on, in this order,
    the model (or ``trainer.model.module`` when the model has a ``module``
    attribute), ``trainer.criterion`` and ``trainer.optimizer``, and calls
    each one it finds with the trainer as the only argument.
    """

    @staticmethod
    def _dispatch(trainer, hook_name: str) -> None:
        """Call ``hook_name(trainer)`` on each trainer component that defines it."""
        for component in ("model", "criterion", "optimizer"):
            # Fetched lazily so an earlier hook may swap a later component.
            target = getattr(trainer, component)
            if component == "model" and hasattr(target, "module"):
                # When a ``module`` attribute exists only the inner module is
                # consulted; the wrapper's own hook is never called.
                target = target.module
            if hasattr(target, hook_name):
                getattr(target, hook_name)(trainer)

    @staticmethod
    def on_init_start(trainer) -> None:
        TrainerCallback._dispatch(trainer, "on_init_start")

    @staticmethod
    def on_init_end(trainer) -> None:
        TrainerCallback._dispatch(trainer, "on_init_end")

    @staticmethod
    def on_epoch_start(trainer) -> None:
        TrainerCallback._dispatch(trainer, "on_epoch_start")

    @staticmethod
    def on_epoch_end(trainer) -> None:
        TrainerCallback._dispatch(trainer, "on_epoch_end")

    @staticmethod
    def on_train_step_start(trainer) -> None:
        TrainerCallback._dispatch(trainer, "on_train_step_start")

    @staticmethod
    def on_train_step_end(trainer) -> None:
        TrainerCallback._dispatch(trainer, "on_train_step_end")

    @staticmethod
    def on_keyboard_interrupt(trainer) -> None:
        TrainerCallback._dispatch(trainer, "on_keyboard_interrupt")
|
||||
@@ -0,0 +1,67 @@
|
||||
from typing import Generator
|
||||
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
|
||||
|
||||
class CapacitronOptimizer:
    """Double optimizer class for the Capacitron model.

    The primary optimizer updates every trainable parameter except
    ``capacitron_vae_layer.beta``; the secondary optimizer updates only that one.
    """

    def __init__(self, config: dict, model_params: Generator) -> None:
        self.primary_params, self.secondary_params = self.split_model_parameters(model_params)

        # The first entry of ``optimizer_params`` configures the primary optimizer,
        # the second one the secondary optimizer.
        entries = list(config.optimizer_params.items())

        self.primary_optimizer = get_optimizer(
            entries[0][0],
            entries[0][1],
            config.lr,
            parameters=self.primary_params,
        )

        secondary_name, secondary_settings = entries[1]
        # The secondary optimizer carries its own learning rate inside its settings.
        self.secondary_optimizer = get_optimizer(
            secondary_name,
            self.extract_optimizer_parameters(secondary_settings),
            secondary_settings["lr"],
            parameters=self.secondary_params,
        )

        self.param_groups = self.primary_optimizer.param_groups

    def first_step(self):
        """Step the secondary optimizer, then clear the gradients of both optimizers."""
        secondary = self.secondary_optimizer
        secondary.step()
        secondary.zero_grad()
        self.primary_optimizer.zero_grad()

    def step(self):
        """Step the primary optimizer."""
        primary = self.primary_optimizer
        # Update param groups to display the correct learning rate
        self.param_groups = primary.param_groups
        primary.step()

    def zero_grad(self, set_to_none=False):
        """Clear the gradients of the primary, then the secondary optimizer."""
        for optimizer in (self.primary_optimizer, self.secondary_optimizer):
            optimizer.zero_grad(set_to_none)

    def load_state_dict(self, state_dict):
        """Restore both optimizers from a ``[primary, secondary]`` sequence of states."""
        for optimizer, state in ((self.primary_optimizer, 0), (self.secondary_optimizer, 1)):
            optimizer.load_state_dict(state_dict[state])

    def state_dict(self):
        """Return ``[primary_state, secondary_state]``."""
        primary_state = self.primary_optimizer.state_dict()
        secondary_state = self.secondary_optimizer.state_dict()
        return [primary_state, secondary_state]

    @staticmethod
    def split_model_parameters(model_params: Generator) -> list:
        """Split trainable ``(name, param)`` pairs into primary and secondary iterators."""
        primary, secondary = [], []
        for name, param in model_params:
            if not param.requires_grad:
                continue
            bucket = secondary if name == "capacitron_vae_layer.beta" else primary
            bucket.append(param)
        return [iter(primary), iter(secondary)]

    @staticmethod
    def extract_optimizer_parameters(params: dict) -> dict:
        """Extract parameters that are not the learning rate"""
        return dict(item for item in params.items() if item[0] != "lr")
|
||||
@@ -0,0 +1,20 @@
|
||||
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def reduce_tensor(tensor, num_gpus):
    """Average a tensor across all processes of the default process group.

    Args:
        tensor (torch.Tensor): Local tensor; it is cloned and left unmodified.
        num_gpus (int): Divisor applied to the summed tensor.

    Returns:
        torch.Tensor: Sum of ``tensor`` over all processes divided by ``num_gpus``.
    """
    rt = tensor.clone()
    # ``dist.reduce_op`` is a deprecated alias; ``ReduceOp`` is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt
|
||||
|
||||
|
||||
def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
    """Bind this process to a CUDA device and join the distributed process group.

    Args:
        rank: Rank of this process; also selects the CUDA device (modulo the device count).
        num_gpus: World size passed to ``init_process_group``.
        group_name: Group name passed to ``init_process_group``.
        dist_backend: Backend passed to ``init_process_group``.
        dist_url: Value passed as ``init_method``.
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."

    # Set cuda device so everything is done on the right GPU.
    device_index = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)

    # Initialize distributed communication
    group_settings = dict(init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
    dist.init_process_group(dist_backend, **group_settings)
|
||||
@@ -0,0 +1,206 @@
|
||||
# Adapted from https://github.com/pytorch/audio/
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import tarfile
|
||||
import urllib
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from os.path import expanduser
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from torch.utils.model_zoo import tqdm
|
||||
|
||||
|
||||
def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
    """Stream url by chunk

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).

    Yields:
        bytes: Consecutive chunks of at most ``block_size`` bytes.
    """

    # If we already have the whole file, there is no need to download it again
    head_request = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(head_request) as response:
        url_size = int(response.info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    request = urllib.request.Request(url)
    if start_byte:
        request.add_header("Range", f"bytes={start_byte}-")

    with urllib.request.urlopen(request) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        # -1 marks a missing Content-Length; show an open-ended bar instead of a negative total.
        total=url_size if url_size >= 0 else None,
        disable=not progress_bar,
    ) as pbar:
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            pbar.update(len(chunk))
|
||||
|
||||
|
||||
def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str or None, optional): Name of downloaded file. If None, it is inferred from the
            response headers and then from the url (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).

    Raises:
        RuntimeError: If the target file exists and ``resume`` is False, or if the
            downloaded (or already complete) file does not match ``hash_value``.
    """

    # Only the headers are needed; close the HEAD response right away.
    head_request = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(head_request) as response:
        req_info = response.info()

    # Detect filename
    if not filename:
        header_name = req_info.get_filename()
        # The header is controlled by the server: keep only the final path
        # component so it cannot point outside ``download_folder``.
        filename = os.path.basename(header_name) if header_name else None
        filename = filename or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)

    already_there = os.path.exists(filepath)
    if already_there and not resume:
        raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
    if already_there:
        mode = "ab"
        local_size: Optional[int] = os.path.getsize(filepath)
    else:
        mode = "wb"
        local_size = None

    # A complete local file only needs its hash verified.
    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    if hash_value:
        with open(filepath, "rb") as file_obj:
            matches = validate_file(file_obj, hash_value, hash_type)
        if not matches:
            raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
|
||||
|
||||
|
||||
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from.
        hash_value (str): Expected hex digest; compared case-insensitively.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: return True if its a valid file, else False.

    Raises:
        ValueError: If ``hash_type`` is neither "sha256" nor "md5".
    """

    if hash_type not in ("sha256", "md5"):
        raise ValueError(f"Unsupported hash_type {hash_type!r}; expected 'sha256' or 'md5'.")
    hash_func = hashlib.new(hash_type)

    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024**2)
        if not chunk:
            break
        hash_func.update(chunk)

    # ``hexdigest`` is lowercase; accept upper-case digests as well.
    return hash_func.hexdigest() == hash_value.lower()
|
||||
|
||||
|
||||
def _is_within_directory(directory: str, target: str) -> bool:
    """Return True when ``target`` resolves to a location inside ``directory``."""
    root = os.path.realpath(directory or ".")
    resolved = os.path.realpath(target)
    return os.path.commonpath([root, resolved]) == root


def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract archive.

    The file is first opened as a tar archive (any compression ``tarfile``
    detects) and, failing that, as a zip archive.

    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
            (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        list: List of paths to extracted files even if not overwritten. For tar archives the
        entries are joined with ``to_path``; for zip archives they are the member names as
        stored in the archive.

    Raises:
        ValueError: If a tar member would be written outside ``to_path``.
        NotImplementedError: If the file is neither a readable tar nor a zip archive.
    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    try:
        with tarfile.open(from_path, "r") as tar:
            logging.info("Opened tar file %s.", from_path)
            files = []
            for file_ in tar:  # type: Any
                file_path = os.path.join(to_path, file_.name)
                # Member names come from the (downloaded) archive: refuse
                # absolute paths and ``..`` components that escape to_path.
                if not _is_within_directory(to_path, file_path):
                    raise ValueError(
                        f"Refusing to extract {file_.name!r} from {from_path}: path escapes {to_path!r}."
                    )
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("%s already extracted.", file_path)
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        # Not a tar archive; try zip next.
        pass

    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logging.info("Opened zip file %s.", from_path)
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logging.info("%s already extracted.", file_path)
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
        return files
    except zipfile.BadZipFile:
        pass

    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
|
||||
|
||||
|
||||
def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
    """Download dataset from kaggle.

    An ``OSError`` raised while importing, authenticating or downloading is
    reported with a hint about the kaggle api token instead of being propagated.

    Args:
        dataset_path (str):
            This the kaggle link to the dataset. for example vctk is 'mfekadu/english-multispeaker-corpus-for-voice-cloning'
        dataset_name (str): Name of the folder the dataset will be saved in.
        output_path (str): Path of the location you want the dataset folder to be saved to.
    """
    data_path = os.path.join(output_path, dataset_name)
    try:
        # Imported lazily so the dependency is only needed when this function is used.
        import kaggle  # pylint: disable=import-outside-toplevel

        api = kaggle.api
        api.authenticate()
        print(f"""\nDownloading {dataset_name}...""")
        api.dataset_download_files(dataset_path, path=data_path, unzip=True)
    except OSError:
        print(
            f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
        )
|
||||
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
|
||||
|
||||
|
||||
def download_ljspeech(path: str):
    """Fetch the LJSpeech archive into ``path`` and unpack it there.

    Args:
        path (str): path to the directory where the dataset will be stored.
    """
    url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
    os.makedirs(path, exist_ok=True)
    download_url(url, path)
    # The archive is expected under the url's base name inside ``path``.
    archive = os.path.join(path, os.path.basename(url))
    print(" > Extracting archive file...")
    extract_archive(archive)
|
||||
|
||||
|
||||
def download_vctk(path: str, use_kaggle: Optional[bool] = False):
    """Fetch the VCTK dataset, either from kaggle or from the Edinburgh DataShare archive.

    Args:
        path (str): path to the directory where the dataset will be stored.

        use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False.
    """
    if use_kaggle:
        download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path)
        return

    url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
    os.makedirs(path, exist_ok=True)
    download_url(url, path)
    # The archive is expected under the url's base name inside ``path``.
    archive = os.path.join(path, os.path.basename(url))
    print(" > Extracting archive file...")
    extract_archive(archive)
|
||||
|
||||
|
||||
def download_tweb(path: str):
    """Fetch the Tweb dataset from kaggle into a ``TWEB`` folder under ``path``.

    Args:
        path (str): Path to the directory where the dataset will be stored.
    """
    kaggle_id = "bryanpark/the-world-english-bible-speech-dataset"
    download_kaggle_dataset(kaggle_id, "TWEB", path)
|
||||
|
||||
|
||||
def download_libri_tts(path: str, subset: Optional[str] = "all"):
    """Fetch and unpack one or all LibriTTS subsets.

    Args:
        path (str): Path to the directory where the dataset will be stored.

        subset (str, optional): Name of the subset to download. If you only want to download a certain
        portion specify it here. Must be "all" or one of the ``libri-tts-*`` keys listed below.
        Defaults to 'all'.
    """

    subset_dict = {
        "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz",
        "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz",
        "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz",
        "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz",
        "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz",
        "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz",
        "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz",
    }

    os.makedirs(path, exist_ok=True)

    if subset != "all":
        # Single subset: an unknown name surfaces as the dict's KeyError.
        url = subset_dict[subset]
        download_url(url, path)
        archive = os.path.join(path, os.path.basename(url))
        print(" > Extracting archive file...")
        extract_archive(archive)
        return

    for sub, val in subset_dict.items():
        print(f" > Downloading {sub}...")
        download_url(val, path)
        archive = os.path.join(path, os.path.basename(val))
        print(" > Extracting archive file...")
        extract_archive(archive)
    print(" > All subsets downloaded")
|
||||
|
||||
|
||||
def download_thorsten_de(path: str):
    """Download the Thorsten German male voice archive and extract it.

    Args:
        path (str): Directory that receives the archive; it is created if missing.
    """
    archive_url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz"
    os.makedirs(path, exist_ok=True)
    download_url(archive_url, path)
    # The downloaded file keeps the name it has in the URL.
    archive_file = os.path.join(path, os.path.basename(archive_url))
    print(" > Extracting archive file...")
    extract_archive(archive_file)
|
||||
|
||||
|
||||
def download_mailabs(path: str, language: str = "english"):
    """Download the M-AILABS archive for one language and extract it.

    Args:
        path (str): Directory that receives the archive; it is created if missing.
        language (str): Language subset to download; one of the keys of the
            table below. Defaults to english.

    Raises:
        KeyError: If ``language`` is not one of the listed keys (raised after
            ``path`` has been created).
    """
    archive_urls = {
        "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz",
        "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz",
        "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz",
        "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz",
        "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz",
    }
    os.makedirs(path, exist_ok=True)
    archive_url = archive_urls[language]
    download_url(archive_url, path)
    archive_file = os.path.join(path, os.path.basename(archive_url))
    print(" > Extracting archive file...")
    extract_archive(archive_file)
|
||||
@@ -0,0 +1,239 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
|
||||
|
||||
def to_cuda(x: torch.Tensor) -> torch.Tensor:
    """Return a contiguous copy of ``x``, moved to CUDA when CUDA is available.

    ``None`` and non-tensor inputs are handed back unchanged.
    """
    if x is None or not torch.is_tensor(x):
        return x
    contiguous = x.contiguous()
    if torch.cuda.is_available():
        return contiguous.cuda(non_blocking=True)
    return contiguous
|
||||
|
||||
|
||||
def get_cuda():
    """Report CUDA availability together with the matching torch device.

    Returns:
        Tuple[bool, torch.device]: the availability flag and ``cuda:0`` when
            CUDA is available, otherwise the ``cpu`` device.
    """
    use_cuda = torch.cuda.is_available()
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    return use_cuda, torch.device(device_name)
|
||||
|
||||
|
||||
def get_git_branch():
    """Return the name of the currently checked-out git branch.

    Runs ``git branch`` and picks the line that starts with ``*``.

    Returns:
        str: The current line with the ``"* "`` marker removed,
            ``"inside_docker"`` when the git command exits with an error, or
            ``"unknown"`` when the git executable is missing or no line is
            marked as current.
    """
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
        current = next(line for line in out.split("\n") if line.startswith("*"))
        # str.replace returns a new string; the result has to be re-bound,
        # otherwise the "* " marker stays in the returned name.
        current = current.replace("* ", "")
    except subprocess.CalledProcessError:
        current = "inside_docker"
    except (FileNotFoundError, StopIteration):
        current = "unknown"
    return current
|
||||
|
||||
|
||||
def get_commit_hash():
    """Return the short hash of the current git HEAD.

    See https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script

    Returns:
        str: Output of ``git rev-parse --short HEAD``, or ``"0000000"`` when
            the command fails or git is not installed.
    """
    git_cmd = ["git", "rev-parse", "--short", "HEAD"]
    try:
        return subprocess.check_output(git_cmd).decode().strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Happens e.g. when the .git folder is not copied into a docker container.
        return "0000000"
|
||||
|
||||
|
||||
def get_experiment_folder_path(root_path, model_name):
    """Build ``<root_path>/<model_name>-<date>-<commit>`` for a new experiment.

    The date part uses the current local time formatted as
    ``%B-%d-%Y_%I+%M%p``; the commit part comes from ``get_commit_hash()``.
    """
    timestamp = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    commit_hash = get_commit_hash()
    folder_name = "-".join((model_name, timestamp, commit_hash))
    return os.path.join(root_path, folder_name)
|
||||
|
||||
|
||||
def remove_experiment_folder(experiment_path):
    """Delete the experiment folder unless it already holds a ``*.pth`` checkpoint.

    The path is resolved through fsspec, so it may point to any filesystem
    fsspec can map.
    """
    fs = fsspec.get_mapper(experiment_path).fs
    if fs.glob(experiment_path + "/*.pth"):
        # At least one checkpoint exists: keep the run.
        print(" ! Run is kept in {}".format(experiment_path))
        return
    if fs.exists(experiment_path):
        fs.rm(experiment_path, recursive=True)
        print(" ! Run is removed from {}".format(experiment_path))
|
||||
|
||||
|
||||
def count_parameters(model):
    r"""Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
|
||||
|
||||
def to_camel(text):
    """Convert a snake_case module name into the matching CamelCase class name.

    The text is capitalized, every ``_x`` after the first character becomes
    ``X``, and the fragments ``Tts`` / ``vc`` are upper-cased to ``TTS`` / ``VC``.
    """
    camel = re.sub(
        r"(?!^)_([a-zA-Z])",
        lambda match: match.group(1).upper(),
        text.capitalize(),
    )
    return camel.replace("Tts", "TTS").replace("vc", "VC")
|
||||
|
||||
|
||||
def find_module(module_path: str, module_name: str) -> object:
    """Import ``<module_path>.<module_name>`` and return its CamelCase-named attribute.

    Args:
        module_path (str): Dotted path of the package that holds the module.
        module_name (str): Module name; it is lower-cased before the import.

    Returns:
        object: The attribute of the imported module whose name is
            ``to_camel(module_name.lower())``.
    """
    lowered_name = module_name.lower()
    module = importlib.import_module(f"{module_path}.{lowered_name}")
    return getattr(module, to_camel(lowered_name))
|
||||
|
||||
|
||||
def import_class(module_path: str) -> object:
    """Import a class from its dotted path.

    Args:
        module_path (str): Dotted path whose last component is the class name,
            e.g. ``"package.module.ClassName"``.

    Returns:
        object: The attribute named by the last component, looked up on the
            module named by everything before it.
    """
    parent_module, _, class_name = module_path.rpartition(".")
    module = importlib.import_module(parent_module)
    return getattr(module, class_name)
|
||||
|
||||
|
||||
def get_import_path(obj: object) -> str:
    """Return the dotted import path of the type of ``obj``.

    Args:
        obj (object): Any object; its type (``type(obj)``) is what gets described.

    Returns:
        str: ``"<module>.<type name>"`` of ``type(obj)``.
    """
    cls = type(obj)
    return ".".join((cls.__module__, cls.__name__))
|
||||
|
||||
|
||||
def get_user_data_dir(appname):
    """Return the per-user data directory for ``appname``.

    The base directory is chosen in this order: the ``TTS_HOME`` environment
    variable, the ``XDG_DATA_HOME`` environment variable, the "Local AppData"
    shell folder on Windows, ``~/Library/Application Support/`` on macOS, and
    ``~/.local/share`` everywhere else.

    Returns:
        pathlib.Path: ``<base directory>/<appname>``.
    """
    tts_home = os.environ.get("TTS_HOME")
    xdg_data_home = os.environ.get("XDG_DATA_HOME")
    if tts_home is not None:
        base_dir = Path(tts_home).expanduser().resolve(strict=False)
    elif xdg_data_home is not None:
        base_dir = Path(xdg_data_home).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        shell_folders = winreg.OpenKey(
            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
        )
        local_app_data, _ = winreg.QueryValueEx(shell_folders, "Local AppData")
        base_dir = Path(local_app_data).resolve(strict=False)
    elif sys.platform == "darwin":
        base_dir = Path("~/Library/Application Support/").expanduser()
    else:
        base_dir = Path.home().joinpath(".local/share")
    return base_dir.joinpath(appname)
|
||||
|
||||
|
||||
def set_init_dict(model_dict, checkpoint_state, c):
    """Partially initialize ``model_dict`` from ``checkpoint_state``.

    Checkpoint entries are copied into ``model_dict`` only when the key exists
    in the model, the element counts (``numel()``) match, and the key contains
    none of the names in ``c.reinit_layers`` (when that option is set).

    Args:
        model_dict (dict): State dict of the model; updated in place.
        checkpoint_state (dict): State dict loaded from a checkpoint.
        c: Config object exposing ``has(name)`` and optionally ``reinit_layers``.

    Returns:
        dict: The same ``model_dict`` object after the update.
    """
    # Report checkpoint layers the model definition does not know about.
    for layer_name in checkpoint_state:
        if layer_name not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(layer_name))
    # Keep only the layers that exist in the model with the same element count.
    restorable = {
        layer_name: weights
        for layer_name, weights in checkpoint_state.items()
        if layer_name in model_dict and weights.numel() == model_dict[layer_name].numel()
    }
    # Drop the layers that were explicitly requested to be re-initialized.
    if c.has("reinit_layers") and c.reinit_layers is not None:
        for reinit_layer_name in c.reinit_layers:
            restorable = {
                layer_name: weights
                for layer_name, weights in restorable.items()
                if reinit_layer_name not in layer_name
            }
    # Overwrite the surviving entries in the model state dict.
    model_dict.update(restorable)
    print(" | > {} / {} layers are restored.".format(len(restorable), len(model_dict)))
    return model_dict
|
||||
|
||||
|
||||
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Fill missing or ``None`` auxiliary inputs with their default values.

    Args:
        def_args (Dict): Argument names mapped to the default used when the
            name is absent from ``kwargs`` or mapped to ``None`` there.
        kwargs (Dict): Auxiliary inputs passed to the model; not modified.

    Returns:
        Dict: A copy of ``kwargs`` with the defaults applied.
    """
    merged = kwargs.copy()
    for name, default in def_args.items():
        if merged.get(name) is None:
            merged[name] = default
    return merged
|
||||
|
||||
|
||||
class KeepAverage:
    """Keep a running average for any number of named values.

    ``avg_values`` maps a name to its current average and ``iters`` maps the
    same name to the number of updates folded into that average.
    """

    def __init__(self):
        self.avg_values = {}
        self.iters = {}

    def __getitem__(self, key):
        return self.avg_values[key]

    def items(self):
        """Return the ``(name, average)`` view of all tracked values."""
        return self.avg_values.items()

    def add_value(self, name, init_val=0, init_iter=0):
        """Start (or restart) tracking ``name`` with the given value and count."""
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        """Fold ``value`` into the average stored under ``name``.

        An unknown name is registered with ``value`` and an update count of 0,
        so the next plain update replaces it (``avg * 0 + value``). With
        ``weighted_avg`` the new average is ``0.99 * old + 0.01 * value``;
        otherwise it is the arithmetic mean over the counted updates.
        """
        if name not in self.avg_values:
            self.add_value(name, init_val=value)
            return
        previous = self.avg_values[name]
        if weighted_avg:
            self.avg_values[name] = 0.99 * previous + 0.01 * value
            self.iters[name] += 1
            return
        # Rebuild the running sum, count the new sample, then renormalize.
        self.avg_values[name] = previous * self.iters[name] + value
        self.iters[name] += 1
        self.avg_values[name] /= self.iters[name]

    def add_values(self, name_dict):
        """Register every ``name -> initial value`` pair of ``name_dict``."""
        for name, initial in name_dict.items():
            self.add_value(name, init_val=initial)

    def update_values(self, value_dict):
        """Apply a plain (unweighted) update for every pair of ``value_dict``."""
        for name, value in value_dict.items():
            self.update_value(name, value)
|
||||
|
||||
|
||||
def get_timestamp():
    """Return the current local time formatted as ``yymmdd-HHMMSS``.

    The file imports the ``datetime`` module (not the class), so the call has
    to go through ``datetime.datetime``; ``datetime.now()`` would raise
    ``AttributeError``.
    """
    return datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
|
||||
|
||||
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
    """Configure the named logger with optional file and console handlers.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        root: Directory of the log file (only used when ``tofile`` is set).
        phase: Prefix of the log file name; the file is ``<phase>_<timestamp>.log``.
        level: Logging level set on the logger. Defaults to ``logging.INFO``.
        screen (bool): Attach a ``StreamHandler``. Defaults to False.
        tofile (bool): Attach a ``FileHandler`` opened in write mode. Defaults to False.
    """
    logger = logging.getLogger(logger_name)
    formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
    logger.setLevel(level)
    if tofile:
        log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
        file_handler = logging.FileHandler(log_file, mode="w")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if screen:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
|
||||
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
import pickle as pickle_tts
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
||||
|
||||
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that maps the old ``mozilla_voice_tts`` module name to ``TTS``."""

    def find_class(self, module, name):
        # Rewrite the module path before the regular class lookup runs.
        renamed_module = module.replace("mozilla_voice_tts", "TTS")
        return super().find_class(renamed_module, name)
|
||||
|
||||
|
||||
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Using the dict itself as the instance namespace makes items and
        # attributes share one storage.
        self.__dict__ = self
|
||||
|
||||
|
||||
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Any path or url supported by fsspec.
        map_location: Forwarded to ``torch.load`` as ``map_location``.
        cache: If True and ``path`` is not an existing local file or directory,
            the file is opened through fsspec's ``filecache`` with its storage
            under ``get_user_data_dir("tts_cache")``. Defaults to True.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    is_local = os.path.isdir(path) or os.path.isfile(path)
    if cache and not is_local:
        # Remote file: go through the local file cache.
        opened = fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        )
    else:
        opened = fsspec.open(path, "rb")
    with opened as f:
        return torch.load(f, map_location=map_location, **kwargs)
|
||||
|
||||
|
||||
def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load a checkpoint onto the CPU and copy its ``"model"`` weights into ``model``.

    If loading fails with ``ModuleNotFoundError`` the load is retried with
    ``RenamingUnpickler`` installed as the unpickler of the pickle module
    handed to ``torch.load``.

    Args:
        model: Module that receives ``state["model"]`` via ``load_state_dict``.
        checkpoint_path: Path or url of the checkpoint, passed to ``load_fsspec``.
        use_cuda (bool): Call ``model.cuda()`` after loading. Defaults to False.
        eval (bool): Call ``model.eval()`` after loading. Defaults to False.
        cache (bool): Forwarded to ``load_fsspec``. Defaults to False.

    Returns:
        Tuple: ``(model, state)`` where ``state`` is the loaded checkpoint.
    """
    cpu_device = torch.device("cpu")
    try:
        state = load_fsspec(checkpoint_path, map_location=cpu_device, cache=cache)
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
        state = load_fsspec(checkpoint_path, map_location=cpu_device, pickle_module=pickle_tts, cache=cache)
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state
|
||||
@@ -0,0 +1,621 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tarfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from shutil import copyfile, rmtree
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import fsspec
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config, read_json_with_comments
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
||||
# Lower-cased license names mapped to a page describing that license.
LICENSE_URLS = {
    # Creative Commons
    "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
    # Mozilla Public License, under the spellings used for it
    "mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/",
    # MIT and Apache
    "mit": "https://choosealicense.com/licenses/mit/",
    "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/",
    "apache2": "https://choosealicense.com/licenses/apache-2.0/",
    "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
    # Coqui Public Model License
    "cpml": "https://coqui.ai/cpml.txt",
}
|
||||
|
||||
|
||||
class ModelManager(object):
|
||||
tqdm_progress = None
|
||||
"""Manage TTS models defined in .models.json.
|
||||
It provides an interface to list and download
|
||||
models defined in '.models.json'
|
||||
|
||||
Models are downloaded under '.TTS' folder in the user's
|
||||
home path.
|
||||
|
||||
Args:
|
||||
models_file (str): path to .model.json file. Defaults to None.
|
||||
output_prefix (str): prefix to `tts` to download models. Defaults to None
|
||||
progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
|
||||
verbose (bool): print info. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
|
||||
super().__init__()
|
||||
self.progress_bar = progress_bar
|
||||
self.verbose = verbose
|
||||
if output_prefix is None:
|
||||
self.output_prefix = get_user_data_dir("tts")
|
||||
else:
|
||||
self.output_prefix = os.path.join(output_prefix, "tts")
|
||||
self.models_dict = None
|
||||
if models_file is not None:
|
||||
self.read_models_file(models_file)
|
||||
else:
|
||||
# try the default location
|
||||
path = Path(__file__).parent / "../.models.json"
|
||||
self.read_models_file(path)
|
||||
|
||||
def read_models_file(self, file_path):
    """Load the model catalogue into ``self.models_dict``.

    Args:
        file_path (str): path to .models.json; parsed with ``read_json_with_comments``.
    """
    catalogue = read_json_with_comments(file_path)
    self.models_dict = catalogue
|
||||
|
||||
def _list_models(self, model_type, model_count=0):
|
||||
if self.verbose:
|
||||
print("\n Name format: type/language/dataset/model")
|
||||
model_list = []
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
for model in self.models_dict[model_type][lang][dataset]:
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if self.verbose:
|
||||
if os.path.exists(output_path):
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
|
||||
else:
|
||||
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
|
||||
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
|
||||
model_count += 1
|
||||
return model_list
|
||||
|
||||
def _list_for_model_type(self, model_type):
|
||||
models_name_list = []
|
||||
model_count = 1
|
||||
models_name_list.extend(self._list_models(model_type, model_count))
|
||||
return models_name_list
|
||||
|
||||
def list_models(self):
|
||||
models_name_list = []
|
||||
model_count = 1
|
||||
for model_type in self.models_dict:
|
||||
model_list = self._list_models(model_type, model_count)
|
||||
models_name_list.extend(model_list)
|
||||
return models_name_list
|
||||
|
||||
def model_info_by_idx(self, model_query):
|
||||
"""Print the description of the model from .models.json file using model_idx
|
||||
|
||||
Args:
|
||||
model_query (str): <model_tye>/<model_idx>
|
||||
"""
|
||||
model_name_list = []
|
||||
model_type, model_query_idx = model_query.split("/")
|
||||
try:
|
||||
model_query_idx = int(model_query_idx)
|
||||
if model_query_idx <= 0:
|
||||
print("> model_query_idx should be a positive integer!")
|
||||
return
|
||||
except:
|
||||
print("> model_query_idx should be an integer!")
|
||||
return
|
||||
model_count = 0
|
||||
if model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
for model in self.models_dict[model_type][lang][dataset]:
|
||||
model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
|
||||
model_count += 1
|
||||
else:
|
||||
print(f"> model_type {model_type} does not exist in the list.")
|
||||
return
|
||||
if model_query_idx > model_count:
|
||||
print(f"model query idx exceeds the number of available models [{model_count}] ")
|
||||
else:
|
||||
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
|
||||
print(f"> model type : {model_type}")
|
||||
print(f"> language supported : {lang}")
|
||||
print(f"> dataset used : {dataset}")
|
||||
print(f"> model name : {model}")
|
||||
if "description" in self.models_dict[model_type][lang][dataset][model]:
|
||||
print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
|
||||
else:
|
||||
print("> description : coming soon")
|
||||
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
|
||||
print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")
|
||||
|
||||
def model_info_by_full_name(self, model_query_name):
|
||||
"""Print the description of the model from .models.json file using model_full_name
|
||||
|
||||
Args:
|
||||
model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
|
||||
"""
|
||||
model_type, lang, dataset, model = model_query_name.split("/")
|
||||
if model_type in self.models_dict:
|
||||
if lang in self.models_dict[model_type]:
|
||||
if dataset in self.models_dict[model_type][lang]:
|
||||
if model in self.models_dict[model_type][lang][dataset]:
|
||||
print(f"> model type : {model_type}")
|
||||
print(f"> language supported : {lang}")
|
||||
print(f"> dataset used : {dataset}")
|
||||
print(f"> model name : {model}")
|
||||
if "description" in self.models_dict[model_type][lang][dataset][model]:
|
||||
print(
|
||||
f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
|
||||
)
|
||||
else:
|
||||
print("> description : coming soon")
|
||||
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
|
||||
print(
|
||||
f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
|
||||
)
|
||||
else:
|
||||
print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
|
||||
else:
|
||||
print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
|
||||
else:
|
||||
print(f"> lang {lang} does not exist for {model_type}.")
|
||||
else:
|
||||
print(f"> model_type {model_type} does not exist in the list.")
|
||||
|
||||
def list_tts_models(self):
|
||||
"""Print all `TTS` models and return a list of model names
|
||||
|
||||
Format is `language/dataset/model`
|
||||
"""
|
||||
return self._list_for_model_type("tts_models")
|
||||
|
||||
def list_vocoder_models(self):
|
||||
"""Print all the `vocoder` models and return a list of model names
|
||||
|
||||
Format is `language/dataset/model`
|
||||
"""
|
||||
return self._list_for_model_type("vocoder_models")
|
||||
|
||||
def list_vc_models(self):
|
||||
"""Print all the voice conversion models and return a list of model names
|
||||
|
||||
Format is `language/dataset/model`
|
||||
"""
|
||||
return self._list_for_model_type("voice_conversion_models")
|
||||
|
||||
def list_langs(self):
|
||||
"""Print all the available languages"""
|
||||
print(" Name format: type/language")
|
||||
for model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
print(f" >: {model_type}/{lang} ")
|
||||
|
||||
def list_datasets(self):
|
||||
"""Print all the datasets"""
|
||||
print(" Name format: type/language/dataset")
|
||||
for model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
print(f" >: {model_type}/{lang}/{dataset}")
|
||||
|
||||
@staticmethod
|
||||
def print_model_license(model_item: Dict):
|
||||
"""Print the license of a model
|
||||
|
||||
Args:
|
||||
model_item (dict): model item in the models.json
|
||||
"""
|
||||
if "license" in model_item and model_item["license"].strip() != "":
|
||||
print(f" > Model's license - {model_item['license']}")
|
||||
if model_item["license"].lower() in LICENSE_URLS:
|
||||
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
|
||||
else:
|
||||
print(" > Check https://opensource.org/licenses for more info.")
|
||||
else:
|
||||
print(" > Model's license - No license information available")
|
||||
|
||||
def _download_github_model(self, model_item: Dict, output_path: str):
|
||||
if isinstance(model_item["github_rls_url"], list):
|
||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
else:
|
||||
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
|
||||
def _download_hf_model(self, model_item: Dict, output_path: str):
|
||||
if isinstance(model_item["hf_url"], list):
|
||||
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
|
||||
else:
|
||||
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
|
||||
|
||||
def download_fairseq_model(self, model_name, output_path):
|
||||
URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
|
||||
_, lang, _, _ = model_name.split("/")
|
||||
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
|
||||
self._download_tar_file(model_download_uri, output_path, self.progress_bar)
|
||||
|
||||
@staticmethod
|
||||
def set_model_url(model_item: Dict):
|
||||
model_item["model_url"] = None
|
||||
if "github_rls_url" in model_item:
|
||||
model_item["model_url"] = model_item["github_rls_url"]
|
||||
elif "hf_url" in model_item:
|
||||
model_item["model_url"] = model_item["hf_url"]
|
||||
elif "fairseq" in model_item["model_name"]:
|
||||
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
|
||||
elif "xtts" in model_item["model_name"]:
|
||||
model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
|
||||
return model_item
|
||||
|
||||
def _set_model_item(self, model_name):
|
||||
# fetch model info from the dict
|
||||
if "fairseq" in model_name:
|
||||
model_type = "tts_models"
|
||||
lang = model_name.split("/")[1]
|
||||
model_item = {
|
||||
"model_type": "tts_models",
|
||||
"license": "CC BY-NC 4.0",
|
||||
"default_vocoder": None,
|
||||
"author": "fairseq",
|
||||
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
||||
}
|
||||
model_item["model_name"] = model_name
|
||||
elif "xtts" in model_name and len(model_name.split("/")) != 4:
|
||||
# loading xtts models with only model name (e.g. xtts_v2.0.2)
|
||||
# check model name has the version number with regex
|
||||
version_regex = r"v\d+\.\d+\.\d+"
|
||||
if re.search(version_regex, model_name):
|
||||
model_version = model_name.split("_")[-1]
|
||||
else:
|
||||
model_version = "main"
|
||||
model_type = "tts_models"
|
||||
lang = "multilingual"
|
||||
dataset = "multi-dataset"
|
||||
model = model_name
|
||||
model_item = {
|
||||
"default_vocoder": None,
|
||||
"license": "CPML",
|
||||
"contact": "info@coqui.ai",
|
||||
"tos_required": True,
|
||||
"hf_url": [
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
|
||||
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth",
|
||||
],
|
||||
}
|
||||
else:
|
||||
# get model from models.json
|
||||
model_type, lang, dataset, model = model_name.split("/")
|
||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||
model_item["model_type"] = model_type
|
||||
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
|
||||
model_item = self.set_model_url(model_item)
|
||||
return model_item, model_full_name, model, md5hash
|
||||
|
||||
@staticmethod
def ask_tos(model_full_path):
    """Ask the user on the console to agree to the terms of service.

    Returns:
        bool: True when the answer is ``y`` (any case); in that case
            ``tos_agreed.txt`` is written into ``model_full_path``. False otherwise.
    """
    tos_path = os.path.join(model_full_path, "tos_agreed.txt")
    print(" > You must confirm the following:")
    print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"')
    print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
    answer = input(" | | > ")
    if answer.lower() != "y":
        return False
    # Remember the agreement so the question is not asked again for this model.
    with open(tos_path, "w", encoding="utf-8") as f:
        f.write("I have read, understood and agreed to the Terms and Conditions.")
    return True
|
||||
|
||||
@staticmethod
|
||||
def tos_agreed(model_item, model_full_path):
|
||||
"""Check if the user has agreed to the terms of service"""
|
||||
if "tos_required" in model_item and model_item["tos_required"]:
|
||||
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
|
||||
if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def create_dir_and_download_model(self, model_name, model_item, output_path):
    """Create ``output_path``, settle the terms of service and download the model.

    Args:
        model_name (str): Model name; a name containing ``fairseq`` selects the
            fairseq download.
        model_item (dict): Catalogue entry holding ``github_rls_url`` or ``hf_url``.
        output_path (str): Directory the model is downloaded into.

    Raises:
        Exception: If the terms of service are required and not accepted; the
            (empty) directory is removed first.
        requests.RequestException: Re-raised after ``output_path`` is deleted
            when the download fails.
    """
    os.makedirs(output_path, exist_ok=True)
    # handle TOS
    if not self.tos_agreed(model_item, output_path) and not self.ask_tos(output_path):
        os.rmdir(output_path)
        raise Exception(" [!] You must agree to the terms of service to use this model.")
    print(f" > Downloading model to {output_path}")
    try:
        if "fairseq" in model_name:
            self.download_fairseq_model(model_name, output_path)
        elif "github_rls_url" in model_item:
            self._download_github_model(model_item, output_path)
        elif "hf_url" in model_item:
            self._download_hf_model(model_item, output_path)
    except requests.RequestException as e:
        print(f" > Failed to download the model file to {output_path}")
        # Do not leave a partially downloaded model behind.
        rmtree(output_path)
        raise e
    self.print_model_license(model_item=model_item)
|
||||
|
||||
def check_if_configs_are_equal(self, model_name, model_item, output_path):
    """Re-download the model when its local config differs from the remote one.

    The local config is the one located by ``_find_files(output_path)``; the
    remote config is the first entry of ``model_item["hf_url"]`` containing
    ``config.json``.
    """
    _, local_config_path = self._find_files(output_path)
    with fsspec.open(local_config_path, "r", encoding="utf-8") as f:
        config_local = json.load(f)
    remote_url = next((url for url in model_item["hf_url"] if "config.json" in url), None)
    with fsspec.open(remote_url, "r", encoding="utf-8") as f:
        config_remote = json.load(f)
    if config_local != config_remote:
        print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
        self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
Model name is in the format
|
||||
'type/language/dataset/model'
|
||||
e.g. 'tts_model/en/ljspeech/tacotron'
|
||||
|
||||
Every model must have the following files:
|
||||
- *.pth : pytorch model checkpoint file.
|
||||
- config.json : model config file.
|
||||
- scale_stats.npy (if exist): scale values for preprocessing.
|
||||
|
||||
Args:
|
||||
model_name (str): model name as explained above.
|
||||
"""
|
||||
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
|
||||
# set the model specific output path
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
if md5sum is not None:
|
||||
md5sum_file = os.path.join(output_path, "hash.md5")
|
||||
if os.path.isfile(md5sum_file):
|
||||
with open(md5sum_file, mode="r") as f:
|
||||
if not f.read() == md5sum:
|
||||
print(f" > {model_name} has been updated, clearing model cache...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
else:
|
||||
print(f" > {model_name} has been updated, clearing model cache...")
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
# if the configs are different, redownload it
|
||||
# ToDo: we need a better way to handle it
|
||||
if "xtts" in model_name:
|
||||
try:
|
||||
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
print(f" > {model_name} is already downloaded.")
|
||||
else:
|
||||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
|
||||
# find downloaded files
|
||||
output_model_path = output_path
|
||||
output_config_path = None
|
||||
if (
|
||||
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
|
||||
): # TODO:This is stupid but don't care for now.
|
||||
output_model_path, output_config_path = self._find_files(output_path)
|
||||
# update paths in the config.json
|
||||
self._update_paths(output_path, output_config_path)
|
||||
return output_model_path, output_config_path, model_item
|
||||
|
||||
@staticmethod
|
||||
def _find_files(output_path: str) -> Tuple[str, str]:
|
||||
"""Find the model and config files in the output path
|
||||
|
||||
Args:
|
||||
output_path (str): path to the model files
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: path to the model file and config file
|
||||
"""
|
||||
model_file = None
|
||||
config_file = None
|
||||
for file_name in os.listdir(output_path):
|
||||
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
|
||||
model_file = os.path.join(output_path, file_name)
|
||||
elif file_name == "config.json":
|
||||
config_file = os.path.join(output_path, file_name)
|
||||
if model_file is None:
|
||||
raise ValueError(" [!] Model file not found in the output path")
|
||||
if config_file is None:
|
||||
raise ValueError(" [!] Config file not found in the output path")
|
||||
return model_file, config_file
|
||||
|
||||
@staticmethod
|
||||
def _find_speaker_encoder(output_path: str) -> str:
|
||||
"""Find the speaker encoder file in the output path
|
||||
|
||||
Args:
|
||||
output_path (str): path to the model files
|
||||
|
||||
Returns:
|
||||
str: path to the speaker encoder file
|
||||
"""
|
||||
speaker_encoder_file = None
|
||||
for file_name in os.listdir(output_path):
|
||||
if file_name in ["model_se.pth", "model_se.pth.tar"]:
|
||||
speaker_encoder_file = os.path.join(output_path, file_name)
|
||||
return speaker_encoder_file
|
||||
|
||||
def _update_paths(self, output_path: str, config_path: str) -> None:
|
||||
"""Update paths for certain files in config.json after download.
|
||||
|
||||
Args:
|
||||
output_path (str): local path the model is downloaded to.
|
||||
config_path (str): local config.json path.
|
||||
"""
|
||||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
||||
output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth")
|
||||
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
||||
output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth")
|
||||
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
|
||||
speaker_encoder_model_path = self._find_speaker_encoder(output_path)
|
||||
|
||||
# update the scale_path.npy file path in the model config.json
|
||||
self._update_path("audio.stats_path", output_stats_path, config_path)
|
||||
|
||||
# update the speakers.json file path in the model config.json to the current path
|
||||
self._update_path("d_vector_file", output_d_vector_file_path, config_path)
|
||||
self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path)
|
||||
self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
|
||||
self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path)
|
||||
|
||||
# update the speaker_ids.json file path in the model config.json to the current path
|
||||
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
|
||||
self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path)
|
||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
|
||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path)
|
||||
|
||||
# update the speaker_encoder file path in the model config.json to the current path
|
||||
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
|
||||
@staticmethod
def _update_path(field_name, new_path, config_path):
    """Update the path in the model config.json for the current environment after download.

    Nothing is changed unless ``new_path`` exists on disk and every key of
    ``field_name`` is already present in the config.

    Args:
        field_name (str): config field to update; a dotted name such as
            ``"model_args.d_vector_file"`` addresses a nested field.
        new_path (str): local path to store in the field. May be ``None``.
        config_path (str): path of the config.json to rewrite in place.
    """
    if not (new_path and os.path.exists(new_path)):
        return
    config = load_config(config_path)
    *parent_keys, leaf_key = field_name.split(".")
    node = config
    for key in parent_keys:
        if key not in node:
            return
        node = node[key]
    # A missing leaf is skipped for nested fields too, matching what the
    # top-level case always did, instead of failing on the lookup below.
    if leaf_key not in node:
        return
    # fields that hold a list of paths keep their list shape
    node[leaf_key] = [new_path] if isinstance(node[leaf_key], list) else new_path
    config.save_json(config_path)
|
||||
|
||||
@staticmethod
def _download_zip_file(file_url, output_folder, progress_bar):
    """Download a zip archive, extract it and flatten its files into ``output_folder``.

    Args:
        file_url (str): URL of the zip archive; its last path segment is
            used as the temporary file name.
        output_folder (str): existing directory that receives the files.
        progress_bar (bool): if truthy, report download progress via tqdm
            on ``ModelManager.tqdm_progress``.

    Raises:
        zipfile.BadZipFile: if the downloaded file is not a valid zip.
    """
    # download the file
    r = requests.get(file_url, stream=True)
    # extract the file
    try:
        total_size_in_bytes = int(r.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        if progress_bar:
            ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
        with open(temp_zip_name, "wb") as file:
            for data in r.iter_content(block_size):
                if progress_bar:
                    ModelManager.tqdm_progress.update(len(data))
                file.write(data)
        with zipfile.ZipFile(temp_zip_name) as z:
            z.extractall(output_folder)
            # read the member list while the archive is still open
            member_names = z.namelist()
        os.remove(temp_zip_name)  # delete zip after extract
    except zipfile.BadZipFile:
        print(f" > Error: Bad zip file - {file_url}")
        # re-raise the original error so its message and traceback survive
        raise
    # move the files to the outer path
    for file_path in member_names:
        src_path = os.path.join(output_folder, file_path)
        if os.path.isfile(src_path):
            dst_path = os.path.join(output_folder, os.path.basename(file_path))
            if src_path != dst_path:
                copyfile(src_path, dst_path)
    # remove redundant (hidden or not) folders
    for file_path in member_names:
        extracted_dir = os.path.join(output_folder, file_path)
        if os.path.isdir(extracted_dir):
            rmtree(extracted_dir)
|
||||
|
||||
@staticmethod
def _download_tar_file(file_url, output_folder, progress_bar):
    """Download a tar archive, extract it and flatten its top folder into ``output_folder``.

    The first member name of the archive is treated as the extracted folder:
    its direct children are copied next to it and the folder is then removed.

    Args:
        file_url (str): URL of the tar archive; its last path segment is
            used as the temporary file name.
        output_folder (str): existing directory that receives the files.
        progress_bar (bool): if truthy, report download progress via tqdm
            on ``ModelManager.tqdm_progress``.

    Raises:
        tarfile.ReadError: if the downloaded file cannot be read as a tar.
    """
    # download the file
    r = requests.get(file_url, stream=True)
    # extract the file
    try:
        total_size_in_bytes = int(r.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        if progress_bar:
            ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
        with open(temp_tar_name, "wb") as file:
            for data in r.iter_content(block_size):
                if progress_bar:
                    ModelManager.tqdm_progress.update(len(data))
                file.write(data)
        with tarfile.open(temp_tar_name) as t:
            # NOTE(review): member paths are extracted as-is; this assumes the
            # archive comes from a trusted release URL - confirm for new sources.
            t.extractall(output_folder)
            tar_names = t.getnames()
        os.remove(temp_tar_name)  # delete tar after extract
    except tarfile.ReadError:
        print(f" > Error: Bad tar file - {file_url}")
        # re-raise the original error so its message and traceback survive
        raise
    # move the files to the outer path
    extracted_root = os.path.join(output_folder, tar_names[0])
    for file_path in os.listdir(extracted_root):
        src_path = os.path.join(extracted_root, file_path)
        dst_path = os.path.join(output_folder, os.path.basename(file_path))
        if src_path != dst_path:
            copyfile(src_path, dst_path)
    # remove the extracted folder
    rmtree(extracted_root)
|
||||
|
||||
@staticmethod
def _download_model_files(file_urls, output_folder, progress_bar):
    """Download each URL into ``output_folder`` under its own file name.

    Args:
        file_urls (Iterable[str]): URLs to fetch; the last path segment of
            each URL becomes the local file name.
        output_folder (str): existing directory that receives the files.
        progress_bar (bool): if truthy, report per-file download progress
            via tqdm on ``ModelManager.tqdm_progress``.
    """
    chunk_size = 1024  # 1 Kibibyte
    for url in file_urls:
        response = requests.get(url, stream=True)
        target_path = os.path.join(output_folder, url.split("/")[-1])
        expected_bytes = int(response.headers.get("content-length", 0))
        with open(target_path, "wb") as out_file:
            if progress_bar:
                ModelManager.tqdm_progress = tqdm(total=expected_bytes, unit="iB", unit_scale=True)
            for chunk in response.iter_content(chunk_size):
                if progress_bar:
                    ModelManager.tqdm_progress.update(len(chunk))
                out_file.write(chunk)
|
||||
|
||||
@staticmethod
|
||||
def _check_dict_key(my_dict, key):
|
||||
if key in my_dict.keys() and my_dict[key] is not None:
|
||||
if not isinstance(key, str):
|
||||
return True
|
||||
if isinstance(key, str) and len(my_dict[key]) > 0:
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,105 @@
|
||||
# modified from https://github.com/LiyuanLucasLiu/RAdam
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
|
||||
class RAdam(Optimizer):
    """Rectified Adam optimizer, modified from https://github.com/LiyuanLucasLiu/RAdam.

    Args:
        params: iterable of parameters, or a list/tuple of parameter-group dicts.
        lr (float): learning rate. Must be non-negative.
        betas (Tuple[float, float]): decay rates of the first and second
            moment averages, each in ``[0, 1)``.
        eps (float): value added to the denominator. Must be non-negative.
        weight_decay (float): decay factor applied to the weights, scaled by ``lr``.
        degenerated_to_sgd (bool): when the rectification term is not used
            (``N_sma < 5``), take a bias-corrected momentum step if True,
            otherwise leave the parameter untouched for that step.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        for position in (0, 1):
            if not 0.0 <= betas[position] < 1.0:
                raise ValueError(f"Invalid beta parameter at index {position}: {betas[position]}")

        self.degenerated_to_sgd = degenerated_to_sgd
        given_as_groups = isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict)
        if given_as_groups:
            for group in params:
                if "betas" not in group:
                    continue
                own_betas = group["betas"]
                # groups with their own betas get their own step-size cache
                if own_betas[0] != betas[0] or own_betas[1] != betas[1]:
                    group["buffer"] = [[None, None, None] for _ in range(10)]
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "buffer": [[None, None, None] for _ in range(10)],
        }
        super().__init__(params, defaults)

    def __setstate__(self, state):  # pylint: disable=useless-super-delegation
        super().__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                # all math is done on an fp32 copy and written back at the end
                p_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_fp32)

                first_moment, second_moment = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                step_count = state["step"]

                # [step, N_sma, step_size] cached for the ten most recent steps
                slot = group["buffer"][int(step_count % 10)]
                if step_count == slot[0]:
                    n_sma, step_size = slot[1], slot[2]
                else:
                    slot[0] = step_count
                    beta2_t = beta2**step_count
                    n_sma_max = 2 / (1 - beta2) - 1
                    n_sma = n_sma_max - 2 * step_count * beta2_t / (1 - beta2_t)
                    slot[1] = n_sma

                    # more conservative since it's an approximated value
                    if n_sma >= 5:
                        rectification = (
                            (1 - beta2_t)
                            * (n_sma - 4)
                            / (n_sma_max - 4)
                            * (n_sma - 2)
                            / n_sma
                            * n_sma_max
                            / (n_sma_max - 2)
                        )
                        step_size = math.sqrt(rectification) / (1 - beta1**step_count)
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1**step_count)
                    else:
                        step_size = -1
                    slot[2] = step_size

                # more conservative since it's an approximated value
                use_adaptive_step = n_sma >= 5
                if not use_adaptive_step and not step_size > 0:
                    # no rectified step and no SGD fallback: leave p untouched
                    continue

                if group["weight_decay"] != 0:
                    p_fp32.add_(p_fp32, alpha=-group["weight_decay"] * group["lr"])
                if use_adaptive_step:
                    denom = second_moment.sqrt().add_(group["eps"])
                    p_fp32.addcdiv_(first_moment, denom, value=-step_size * group["lr"])
                else:
                    p_fp32.add_(first_moment, alpha=-step_size * group["lr"])
                p.data.copy_(p_fp32)

        return loss
|
||||
@@ -0,0 +1,201 @@
|
||||
import math
|
||||
import random
|
||||
from typing import Callable, List, Union
|
||||
|
||||
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
|
||||
|
||||
|
||||
class SubsetSampler(Sampler):
    """Yield the given indices one after another, in the order supplied.

    Args:
        indices (list): a sequence of indices
    """

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __iter__(self):
        total = len(self.indices)
        return (self.indices[position] for position in range(total))

    def __len__(self):
        return len(self.indices)
|
||||
|
||||
|
||||
class PerfectBatchSampler(Sampler):
|
||||
"""
|
||||
Samples a mini-batch of indices for a balanced class batching
|
||||
|
||||
Args:
|
||||
dataset_items(list): dataset items to sample from.
|
||||
classes (list): list of classes of dataset_items to sample from.
|
||||
batch_size (int): total number of samples to be sampled in a mini-batch.
|
||||
num_gpus (int): number of GPU in the data parallel mode.
|
||||
shuffle (bool): if True, samples randomly, otherwise samples sequentially.
|
||||
drop_last (bool): if True, drops last incomplete batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_items,
|
||||
classes,
|
||||
batch_size,
|
||||
num_classes_in_batch,
|
||||
num_gpus=1,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
label_key="class_name",
|
||||
):
|
||||
super().__init__(dataset_items)
|
||||
assert (
|
||||
batch_size % (num_classes_in_batch * num_gpus) == 0
|
||||
), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
|
||||
|
||||
label_indices = {}
|
||||
for idx, item in enumerate(dataset_items):
|
||||
label = item[label_key]
|
||||
if label not in label_indices.keys():
|
||||
label_indices[label] = [idx]
|
||||
else:
|
||||
label_indices[label].append(idx)
|
||||
|
||||
if shuffle:
|
||||
self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
|
||||
else:
|
||||
self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
|
||||
|
||||
self._batch_size = batch_size
|
||||
self._drop_last = drop_last
|
||||
self._dp_devices = num_gpus
|
||||
self._num_classes_in_batch = num_classes_in_batch
|
||||
|
||||
def __iter__(self):
|
||||
batch = []
|
||||
if self._num_classes_in_batch != len(self._samplers):
|
||||
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
|
||||
else:
|
||||
valid_samplers_idx = None
|
||||
|
||||
iters = [iter(s) for s in self._samplers]
|
||||
done = False
|
||||
|
||||
while True:
|
||||
b = []
|
||||
for i, it in enumerate(iters):
|
||||
if valid_samplers_idx is not None and i not in valid_samplers_idx:
|
||||
continue
|
||||
idx = next(it, None)
|
||||
if idx is None:
|
||||
done = True
|
||||
break
|
||||
b.append(idx)
|
||||
if done:
|
||||
break
|
||||
batch += b
|
||||
if len(batch) == self._batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if valid_samplers_idx is not None:
|
||||
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
|
||||
|
||||
if not self._drop_last:
|
||||
if len(batch) > 0:
|
||||
groups = len(batch) // self._num_classes_in_batch
|
||||
if groups % self._dp_devices == 0:
|
||||
yield batch
|
||||
else:
|
||||
batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
|
||||
def __len__(self):
|
||||
class_batch_size = self._batch_size // self._num_classes_in_batch
|
||||
return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
|
||||
|
||||
|
||||
def identity(x):
    """Return the argument unchanged."""
    return x
|
||||
|
||||
|
||||
class SortedSampler(Sampler):
    """Yield the positions of ``data`` ordered by a sort key, always in the same order.

    Taken from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        data (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            comparison key from each element. Elements with equal keys keep their original order.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key: Callable = identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # compute each key once, then order the positions by their key
        keys = [self.sort_key(row) for row in self.data]
        self.sorted_indexes = sorted(range(len(keys)), key=keys.__getitem__)

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)
|
||||
|
||||
|
||||
class BucketBatchSampler(BatchSampler):
    """Batch sampler that groups samples with similar sort keys.

    Indices from ``sampler`` are collected into buckets, each bucket is sorted
    by ``sort_key``, cut into batches, and the batches of a bucket are yielded
    in random order.

    Adapted from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        sampler (torch.data.utils.sampler.Sampler):
        batch_size (int): Size of mini-batch.
        drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
            than `batch_size`.
        data (list): List of data samples.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.

    Example:
        >>> sampler = WeightedRandomSampler(weights, len(weights))
        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
    """

    def __init__(
        self,
        sampler,
        data,
        batch_size,
        drop_last,
        sort_key: Union[Callable, List] = identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.data = data
        self.sort_key = sort_key
        bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            # never ask for a bucket larger than the sampler itself
            bucket_size = min(bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            ordered = SortedSampler([self.data[index] for index in bucket], self.sort_key)
            bucket_batches = list(BatchSampler(ordered, self.batch_size, self.drop_last))
            for local_batch in SubsetRandomSampler(bucket_batches):
                # map positions inside the bucket back to dataset indices
                yield [bucket[position] for position in local_batch]

    def __len__(self):
        total = len(self.sampler)
        if self.drop_last:
            return total // self.batch_size
        return math.ceil(total / self.batch_size)
|
||||
@@ -0,0 +1,505 @@
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pysbd
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.tts.models.vits import Vits
|
||||
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import save_wav
|
||||
from TTS.vc.models import setup_model as setup_vc_model
|
||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
|
||||
|
||||
|
||||
class Synthesizer(nn.Module):
|
||||
def __init__(
    self,
    tts_checkpoint: str = "",
    tts_config_path: str = "",
    tts_speakers_file: str = "",
    tts_languages_file: str = "",
    vocoder_checkpoint: str = "",
    vocoder_config: str = "",
    encoder_checkpoint: str = "",
    encoder_config: str = "",
    vc_checkpoint: str = "",
    vc_config: str = "",
    model_dir: str = "",
    voice_dir: str = None,
    use_cuda: bool = False,
) -> None:
    """General TTS interface for inference. It takes a tts and a vocoder
    model and synthesizes speech from the provided text.

    The text is divided into a list of sentences using `pysbd` and speech is
    synthesized on each sentence separately.

    If you have certain special characters in your text, you need to handle
    them before providing the text to Synthesizer.

    Each of the tts, vocoder, voice conversion and ``model_dir`` models is
    loaded only when its argument is non-empty; every model that is loaded
    overwrites ``output_sample_rate``, so the last one loaded decides it.

    TODO: set the segmenter based on the source language

    Args:
        tts_checkpoint (str, optional): path to the tts model file.
        tts_config_path (str, optional): path to the tts config file.
        tts_speakers_file (str, optional): stored on the instance as `tts_speakers_file`.
        tts_languages_file (str, optional): stored on the instance as `tts_languages_file`.
        vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to `""`.
        vocoder_config (str, optional): path to the vocoder config file. Defaults to `""`.
        encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`.
        encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`.
        vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`.
        vc_config (str, optional): path to the voice conversion config file. Defaults to `""`.
        model_dir (str, optional): directory to load a model from. A path containing
            "fairseq" is loaded with the fairseq loader. Defaults to `""`.
        voice_dir (str, optional): stored on the instance as `voice_dir`. Defaults to None.
        use_cuda (bool, optional): enable/disable cuda. Defaults to False.
    """
    super().__init__()

    # constructor arguments kept for later use
    self.tts_checkpoint = tts_checkpoint
    self.tts_config_path = tts_config_path
    self.tts_speakers_file = tts_speakers_file
    self.tts_languages_file = tts_languages_file
    self.vocoder_checkpoint = vocoder_checkpoint
    self.vocoder_config = vocoder_config
    self.encoder_checkpoint = encoder_checkpoint
    self.encoder_config = encoder_config
    self.vc_checkpoint = vc_checkpoint
    self.vc_config = vc_config
    self.voice_dir = voice_dir
    self.use_cuda = use_cuda

    # state filled in by the loaders
    self.tts_model = None
    self.vocoder_model = None
    self.vc_model = None
    self.speaker_manager = None
    self.language_manager = None
    self.tts_speakers = {}
    self.tts_languages = {}
    self.num_languages = 0
    self.d_vector_dim = 0
    self.seg = self._get_segmenter("en")

    if self.use_cuda:
        assert torch.cuda.is_available(), "CUDA is not availabe on this machine."

    if tts_checkpoint:
        self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
        self.output_sample_rate = self.tts_config.audio["sample_rate"]

    if vocoder_checkpoint:
        self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
        self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

    if vc_checkpoint:
        self._load_vc(vc_checkpoint, vc_config, use_cuda)
        self.output_sample_rate = self.vc_config.audio["output_sample_rate"]

    if model_dir:
        if "fairseq" in model_dir:
            self._load_fairseq_from_dir(model_dir, use_cuda)
            sample_rate_key = "sample_rate"
        else:
            self._load_tts_from_dir(model_dir, use_cuda)
            sample_rate_key = "output_sample_rate"
        self.output_sample_rate = self.tts_config.audio[sample_rate_key]
|
||||
|
||||
@staticmethod
def _get_segmenter(lang: str):
    """Create the sentence segmenter for the given language.

    Args:
        lang (str): target language code.

    Returns:
        pysbd.Segmenter: segmenter for ``lang``, created with ``clean=True``.
    """
    segmenter = pysbd.Segmenter(language=lang, clean=True)
    return segmenter
|
||||
|
||||
def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None:
    """Load the voice conversion model.

    1. Load the model config.
    2. Init the model from the config.
    3. Load the model weights.
    4. Move the model to the GPU if CUDA is enabled.

    Args:
        vc_checkpoint (str): path to the model checkpoint.
        vc_config_path (str): path to the model config file.
        use_cuda (bool): enable/disable CUDA use.
    """
    config = load_config(vc_config_path)
    self.vc_config = config
    model = setup_vc_model(config=config)
    self.vc_model = model
    model.load_checkpoint(config, vc_checkpoint)
    if use_cuda:
        model.cuda()
|
||||
|
||||
def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
    """Load the fairseq model from a directory.

    The model is assumed to be VITS and to know how to load itself from the
    directory. A default ``VitsConfig`` is used to build the model and is then
    replaced by the config the model exposes after loading.

    Args:
        model_dir (str): directory holding the fairseq checkpoint.
        use_cuda (bool): enable/disable CUDA use.
    """
    initial_config = VitsConfig()
    self.tts_config = initial_config
    model = Vits.init_from_config(initial_config)
    self.tts_model = model
    model.load_fairseq_checkpoint(initial_config, checkpoint_dir=model_dir, eval=True)
    self.tts_config = model.config
    if use_cuda:
        model.cuda()
|
||||
|
||||
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
    """Load the TTS model from a directory.

    The directory is expected to contain a ``config.json``, and the model is
    assumed to know how to load its own weights from the directory.

    Args:
        model_dir (str): directory holding ``config.json`` and the checkpoint.
        use_cuda (bool): enable/disable CUDA use.
    """
    config_file = os.path.join(model_dir, "config.json")
    self.tts_config = load_config(config_file)
    self.tts_model = setup_tts_model(self.tts_config)
    self.tts_model.load_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
    if use_cuda:
        self.tts_model.cuda()
|
||||
|
||||
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
    """Load the TTS model.

    1. Load the model config.
    2. Init the model from the config.
    3. Take the speaker encoder paths from the config if none were given.
    4. Load the model weights.
    5. Move the model to the GPU if CUDA is enabled.
    6. Init the speaker encoder of the model's speaker manager.

    Args:
        tts_checkpoint (str): path to the model checkpoint.
        tts_config_path (str): path to the model config file.
        use_cuda (bool): enable/disable CUDA use.

    Raises:
        ValueError: if the config enables phonemes but defines no phonemizer.
    """
    config = load_config(tts_config_path)
    self.tts_config = config
    if config["use_phonemes"] and config["phonemizer"] is None:
        raise ValueError("Phonemizer is not defined in the TTS config.")

    self.tts_model = setup_tts_model(config=config)

    # fall back to the encoder paths stored in the tts config
    if not self.encoder_checkpoint:
        self._set_speaker_encoder_paths_from_tts_config()

    self.tts_model.load_checkpoint(config, tts_checkpoint, eval=True)
    if use_cuda:
        self.tts_model.cuda()

    has_speaker_manager = hasattr(self.tts_model, "speaker_manager")
    if self.encoder_checkpoint and has_speaker_manager:
        self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda)
|
||||
|
||||
def _set_speaker_encoder_paths_from_tts_config(self):
|
||||
"""Set the encoder paths from the tts model config for models with speaker encoders."""
|
||||
if hasattr(self.tts_config, "model_args") and hasattr(
|
||||
self.tts_config.model_args, "speaker_encoder_config_path"
|
||||
):
|
||||
self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
|
||||
self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
|
||||
|
||||
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
    """Load the vocoder model.

    1. Load the vocoder config.
    2. Init the AudioProcessor for the vocoder.
    3. Init the vocoder model from the config.
    4. Load the model weights.
    5. Move the model to the GPU if CUDA is enabled.

    Args:
        model_file (str): path to the model checkpoint.
        model_config (str): path to the model config file.
        use_cuda (bool): enable/disable CUDA use.
    """
    config = load_config(model_config)
    self.vocoder_config = config
    self.vocoder_ap = AudioProcessor(verbose=False, **config.audio)
    self.vocoder_model = setup_vocoder_model(config)
    self.vocoder_model.load_checkpoint(config, model_file, eval=True)
    if use_cuda:
        self.vocoder_model.cuda()
|
||||
|
||||
def split_into_sentences(self, text) -> List[str]:
    """Split the given text into sentences with the instance's segmenter.

    Args:
        text (str): input text in string format.

    Returns:
        List[str]: list of sentences.
    """
    sentences = self.seg.segment(text)
    return sentences
|
||||
|
||||
def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
    """Save the waveform as a file at the synthesizer's output sample rate.

    Args:
        wav (List[int]): waveform as a list of values; a torch tensor is
            moved to the CPU and converted to a numpy array first.
        path (str): output path to save the waveform.
        pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
    """
    if torch.is_tensor(wav):
        samples = wav.cpu().numpy()
    elif isinstance(wav, list):
        samples = np.array(wav)
    else:
        samples = wav
    save_wav(wav=samples, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
|
||||
|
||||
def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
    """Run the loaded voice conversion model on a source/target pair.

    Args:
        source_wav (str): passed to the model as the source.
        target_wav (str): passed to the model as the target.

    Returns:
        List[int]: whatever the model's ``voice_conversion`` returns.
    """
    return self.vc_model.voice_conversion(source_wav, target_wav)
|
||||
|
||||
def tts(
    self,
    text: str = "",
    speaker_name: str = "",
    language_name: str = "",
    speaker_wav=None,
    style_wav=None,
    style_text=None,
    reference_wav=None,
    reference_speaker_name=None,
    split_sentences: bool = True,
    **kwargs,
) -> List[int]:
    """🐸 TTS magic. Run all the models and generate speech.

    The method runs in one of two modes, selected by ``reference_wav``:

    * ``reference_wav`` is falsy: ``text`` is synthesized sentence by sentence.
    * ``reference_wav`` is given: voice transfer is run on it through
      ``transfer_voice``; ``text`` is not synthesized in this mode, even if given.

    Args:
        text (str): input text.
        speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
        language_name (str, optional): language id for multi-language models. Defaults to "".
        speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None.
        style_wav ([type], optional): style waveform for GST. Defaults to None.
        style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
        reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
        reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None.
        split_sentences (bool, optional): split the input text into sentences. Defaults to True.
        **kwargs: additional arguments to pass to the TTS model. They are only forwarded to
            ``self.tts_model.synthesize``. A ``voice_dir`` entry, if present, is stored on
            ``self.voice_dir`` and removed before forwarding.

    Returns:
        List[int]: in synthesis mode, a flat list holding the samples of every sentence, each
        sentence followed by 10000 zero-valued samples. In voice-conversion mode, the squeezed
        waveform itself (it is not converted to a list).

    Raises:
        ValueError: if neither ``text`` nor ``reference_wav`` is given; if the speaker lookup is
            available, the model has more than one speaker and neither ``speaker_name`` nor
            ``speaker_wav`` is given; if ``speaker_name`` is given but no speaker lookup is
            available and ``self.voice_dir`` is None; or if the language lookup is active with
            more than one language and ``language_name`` is missing, unknown, or not a string.
    """
    start_time = time.time()
    wavs = []

    if not text and not reference_wav:
        raise ValueError(
            "You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API."
        )

    # `sens` is only defined when text is given; it is only read in the synthesis branch below.
    if text:
        sens = [text]
        if split_sentences:
            print(" > Text splitted to sentences.")
            sens = self.split_into_sentences(text)
        print(sens)

    # handle multi-speaker
    # `voice_dir` is consumed here so it is not forwarded to the model through **kwargs.
    if "voice_dir" in kwargs:
        self.voice_dir = kwargs["voice_dir"]
        kwargs.pop("voice_dir")
    speaker_embedding = None
    speaker_id = None
    if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
        if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts":
            if self.tts_config.use_d_vector_file:
                # get the average speaker embedding from the saved d_vectors.
                speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
                    speaker_name, num_samples=None, randomize=False
                )
                speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
            else:
                # get speaker idx from the speaker name
                speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name]
        # handle Neon models with single speaker.
        elif len(self.tts_model.speaker_manager.name_to_id) == 1:
            speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0]
        elif not speaker_name and not speaker_wav:
            raise ValueError(
                " [!] Looks like you are using a multi-speaker model. "
                "You need to define either a `speaker_idx` or a `speaker_wav` to use a multi-speaker model."
            )
        else:
            speaker_embedding = None
    else:
        # No speaker lookup is available: a speaker name is only accepted with a voice directory.
        if speaker_name and self.voice_dir is None:
            raise ValueError(
                f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
                "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
            )

    # handle multi-lingual
    language_id = None
    if self.tts_languages_file or (
        hasattr(self.tts_model, "language_manager")
        and self.tts_model.language_manager is not None
        and not self.tts_config.model == "xtts"
    ):
        # A single known language is selected automatically, regardless of `language_name`.
        if len(self.tts_model.language_manager.name_to_id) == 1:
            language_id = list(self.tts_model.language_manager.name_to_id.values())[0]

        elif language_name and isinstance(language_name, str):
            try:
                language_id = self.tts_model.language_manager.name_to_id[language_name]
            except KeyError as e:
                raise ValueError(
                    f" [!] Looks like you use a multi-lingual model. "
                    f"Language {language_name} is not in the available languages: "
                    f"{self.tts_model.language_manager.name_to_id.keys()}."
                ) from e

        elif not language_name:
            raise ValueError(
                " [!] Look like you use a multi-lingual model. "
                "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model."
            )

        else:
            # Reached when `language_name` is truthy but not a string.
            raise ValueError(
                f" [!] Missing language_ids.json file path for selecting language {language_name}."
                "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
            )

    # compute a new d_vector from the given clip.
    # This overrides any embedding selected from `speaker_name` above.
    if (
        speaker_wav is not None
        and self.tts_model.speaker_manager is not None
        and hasattr(self.tts_model.speaker_manager, "encoder_ap")
        and self.tts_model.speaker_manager.encoder_ap is not None
    ):
        speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

    # Without a vocoder model the waveform produced by the TTS step is used directly
    # (`use_griffin_lim=use_gl` is passed to the synthesis helpers).
    vocoder_device = "cpu"
    use_gl = self.vocoder_model is None
    if not use_gl:
        vocoder_device = next(self.vocoder_model.parameters()).device
    # The CUDA flag takes precedence over the device detected from the vocoder parameters.
    if self.use_cuda:
        vocoder_device = "cuda"

    if not reference_wav:  # not voice conversion
        for sen in sens:
            # Models exposing their own `synthesize` method are called directly;
            # all others go through the generic `synthesis` helper.
            if hasattr(self.tts_model, "synthesize"):
                # Note: the speaker *name* (not the resolved `speaker_id`) is passed here.
                outputs = self.tts_model.synthesize(
                    text=sen,
                    config=self.tts_config,
                    speaker_id=speaker_name,
                    voice_dirs=self.voice_dir,
                    d_vector=speaker_embedding,
                    speaker_wav=speaker_wav,
                    language=language_name,
                    **kwargs,
                )
            else:
                # synthesize voice
                outputs = synthesis(
                    model=self.tts_model,
                    text=sen,
                    CONFIG=self.tts_config,
                    use_cuda=self.use_cuda,
                    speaker_id=speaker_id,
                    style_wav=style_wav,
                    style_text=style_text,
                    use_griffin_lim=use_gl,
                    d_vector=speaker_embedding,
                    language_id=language_id,
                )
            waveform = outputs["wav"]
            if not use_gl:
                mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
                # denormalize tts output based on tts audio config
                mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
                # renormalize spectrogram based on vocoder config
                vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                # compute scale factor for possible sample rate mismatch
                scale_factor = [
                    1,
                    self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
                ]
                if scale_factor[1] != 1:
                    print(" > interpolating tts model output.")
                    vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
                else:
                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                # run vocoder model
                # [1, T, C]
                waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
            # Bring the vocoder output back to the CPU before converting it to NumPy.
            if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
                waveform = waveform.cpu()
            if not use_gl:
                waveform = waveform.numpy()
            waveform = waveform.squeeze()

            # trim silence
            if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
                waveform = trim_silence(waveform, self.tts_model.ap)

            # Append the sentence followed by 10000 zero samples as a pause between sentences.
            wavs += list(waveform)
            wavs += [0] * 10000
    else:
        # get the speaker embedding or speaker id for the reference wav file
        reference_speaker_embedding = None
        reference_speaker_id = None
        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
            if reference_speaker_name and isinstance(reference_speaker_name, str):
                if self.tts_config.use_d_vector_file:
                    # get the speaker embedding from the saved d_vectors.
                    reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name(
                        reference_speaker_name
                    )[0]
                    reference_speaker_embedding = np.array(reference_speaker_embedding)[
                        None, :
                    ]  # [1 x embedding_dim]
                else:
                    # get speaker idx from the speaker name
                    reference_speaker_id = self.tts_model.speaker_manager.name_to_id[reference_speaker_name]
            else:
                # No usable reference speaker name: derive the embedding from the reference clip itself.
                reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
                    reference_wav
                )
        outputs = transfer_voice(
            model=self.tts_model,
            CONFIG=self.tts_config,
            use_cuda=self.use_cuda,
            reference_wav=reference_wav,
            speaker_id=speaker_id,
            d_vector=speaker_embedding,
            use_griffin_lim=use_gl,
            reference_speaker_id=reference_speaker_id,
            reference_d_vector=reference_speaker_embedding,
        )
        waveform = outputs
        if not use_gl:
            mel_postnet_spec = outputs[0].detach().cpu().numpy()
            # denormalize tts output based on tts audio config
            mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
            # renormalize spectrogram based on vocoder config
            vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
            # compute scale factor for possible sample rate mismatch
            scale_factor = [
                1,
                self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
            ]
            if scale_factor[1] != 1:
                print(" > interpolating tts model output.")
                vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
            else:
                vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
            # run vocoder model
            # [1, T, C]
            waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
        if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
            waveform = waveform.cpu()
        if not use_gl:
            waveform = waveform.numpy()
        # Unlike the synthesis branch, the result replaces `wavs` and is not turned into a list.
        wavs = waveform.squeeze()

    # compute stats
    process_time = time.time() - start_time
    # The duration is computed with the TTS config sample rate.
    # NOTE(review): if `wavs` is empty, `audio_time` is 0 and the real-time factor
    # below divides by zero — confirm whether an empty result is reachable.
    audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
    print(f" > Processing time: {process_time}")
    print(f" > Real-time factor: {process_time / audio_time}")
    return wavs
|
||||
@@ -0,0 +1,44 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    r"""Clip the gradient norm and report whether the update should be skipped.

    Args:
        model: module whose parameters are clipped when ``amp_opt_params`` is not given.
        grad_clip: maximum gradient norm passed to ``torch.nn.utils.clip_grad_norm_``.
        ignore_stopnet (bool): when True (and ``amp_opt_params`` is not given), parameters
            whose name contains ``"stopnet"`` are left out of the clipping.
        amp_opt_params: optional parameters to clip instead of the model's own; when
            truthy they take precedence over both model-based selections.

    Returns:
        tuple: ``(grad_norm, skip_flag)`` where ``skip_flag`` is True when the
        gradient norm is infinite.
    """
    # Choose which parameters get clipped.
    if amp_opt_params:
        clip_targets = amp_opt_params
    elif ignore_stopnet:
        clip_targets = [param for name, param in model.named_parameters() if "stopnet" not in name]
    else:
        clip_targets = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(clip_targets, grad_clip)

    # compatibility with different torch versions: the norm may come back
    # as a plain float or as a tensor.
    if isinstance(grad_norm, float):
        norm_is_inf = np.isinf(grad_norm)
    else:
        norm_is_inf = torch.isinf(grad_norm)

    skip_flag = False
    if norm_is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    return grad_norm, skip_flag
|
||||
|
||||
|
||||
def gradual_training_scheduler(global_step, config):
    """Pick the active gradual-training entry for the current step.

    The step is scaled by the number of visible CUDA devices (treated as 1
    when there are none). ``config.gradual_training`` is scanned in order and
    the last entry whose first element is not larger than the scaled step wins.

    Args:
        global_step: current training step.
        config: object exposing ``gradual_training``, an iterable of entries
            indexable at positions 0, 1 and 2.

    Returns:
        tuple: elements 1 and 2 of the selected entry.

    Raises:
        TypeError: if no entry applies to the scaled step, because the
            selection is then still ``None`` when it is indexed.
    """
    # we set the scheduling wrt num_gpus
    effective_step = global_step * max(torch.cuda.device_count(), 1)

    active_entry = None
    for entry in config.gradual_training:
        start_step = entry[0]
        if effective_step >= start_step:
            active_entry = entry
    return active_entry[1], active_entry[2]
|
||||
@@ -0,0 +1,88 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
def read_audio(path):
    """Load an audio file and return it as a single-channel waveform.

    Files whose loaded waveform has more than one entry along the first
    dimension are averaged over that dimension.

    Args:
        path: location of the audio file, passed to ``torchaudio.load``.

    Returns:
        tuple: ``(waveform, sample_rate)`` with the first dimension squeezed
        out of the waveform.
    """
    waveform, sample_rate = torchaudio.load(path)

    num_channels = waveform.size(0)
    if num_channels > 1:
        # Downmix by averaging across the first dimension.
        waveform = waveform.mean(dim=0, keepdim=True)

    return waveform.squeeze(0), sample_rate
|
||||
|
||||
|
||||
def resample_wav(wav, sr, new_sr):
    """Resample a waveform from ``sr`` to ``new_sr``.

    A leading dimension is added before resampling and removed afterwards.

    Args:
        wav: waveform tensor to resample.
        sr: original sample rate.
        new_sr: target sample rate.

    Returns:
        The resampled waveform with the added leading dimension removed.
    """
    batched = wav.unsqueeze(0)
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr)
    resampled = resampler(batched)
    return resampled.squeeze(0)
|
||||
|
||||
|
||||
def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False):
    """Rescale sample-index timestamps from one sample rate to another.

    Args:
        vad_sr: sample rate the timestamps were computed at.
        new_sr: sample rate the timestamps should refer to.
        timestamps: sequence of dicts with ``"start"`` and ``"end"`` entries.
        just_begging_end (bool): when True and ``timestamps`` is non-empty,
            return one span running from the first start to the last end.

    Returns:
        list: dicts with integer ``"start"`` and ``"end"`` values (truncated
        with ``int``); empty when ``timestamps`` is empty.
    """
    factor = new_sr / vad_sr

    def _scaled(start, end):
        # map to the new SR
        return {"start": int(start * factor), "end": int(end * factor)}

    if just_begging_end and timestamps:
        # get just the start and end timestamps
        return [_scaled(timestamps[0]["start"], timestamps[-1]["end"])]

    return [_scaled(ts["start"], ts["end"]) for ts in timestamps]
|
||||
|
||||
|
||||
def get_vad_model_and_utils(use_cuda=False, use_onnx=False):
    """Fetch the Silero VAD model and its helper functions through ``torch.hub``.

    The hub entry is loaded with ``force_reload=True``.

    Args:
        use_cuda (bool): move the loaded model to the GPU.
        use_onnx (bool): forwarded to the hub entry as ``onnx``.

    Returns:
        tuple: ``(model, get_speech_timestamps, save_audio, collect_chunks)``.
    """
    vad_model, vad_utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=True,
        onnx=use_onnx,
        force_onnx_cpu=True,
    )
    if use_cuda:
        vad_model = vad_model.cuda()

    # The utilities unpack as five items; the third and fourth are not used.
    get_speech_timestamps, save_audio, _, _, collect_chunks = vad_utils
    return vad_model, get_speech_timestamps, save_audio, collect_chunks
|
||||
|
||||
|
||||
def remove_silence(
    model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False
):
    """Remove non-speech parts from an audio file using a VAD model and save the result.

    Args:
        model_and_utils: 4-tuple ``(model, get_speech_timestamps, save_audio, collect_chunks)``;
            the third item is not used here.
        audio_path: path of the audio file to process.
        out_path: path the processed audio is written to.
        vad_sample_rate (int): sample rate the VAD is run at; the audio is resampled
            to it when its own rate differs. Defaults to 8000.
        trim_just_beginning_and_end (bool): when True, only one span from the first
            detected start to the last detected end is kept. Defaults to True.
        use_cuda (bool): move the VAD input to the GPU. Defaults to False.

    Returns:
        tuple: ``(out_path, is_speech)``. ``is_speech`` is False when no speech was
        detected, in which case the audio is saved unmodified. If the file cannot
        be read, ``(None, False)`` is returned and nothing is saved.
    """
    # get the VAD model and utils functions
    model, get_speech_timestamps, _, collect_chunks = model_and_utils

    # read ground truth wav and resample the audio for the VAD
    try:
        wav, gt_sample_rate = read_audio(audio_path)
    # Best effort: an unreadable file is reported and skipped. Only `Exception` is
    # caught so KeyboardInterrupt/SystemExit are no longer swallowed.
    except Exception:  # pylint: disable=broad-except
        print(f"> ❗ Failed to read {audio_path}")
        return None, False

    # if needed, resample the audio for the VAD model
    if gt_sample_rate != vad_sample_rate:
        wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate)
    else:
        wav_vad = wav

    if use_cuda:
        wav_vad = wav_vad.cuda()

    # get speech timestamps from full audio file
    speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768)

    # map the current speech_timestamps to the sample rate of the ground truth audio
    new_speech_timestamps = map_timestamps_to_new_sr(
        vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end
    )

    # if have speech timestamps else save the wav
    if new_speech_timestamps:
        wav = collect_chunks(new_speech_timestamps, wav)
        is_speech = True
    else:
        print(f"> The file {audio_path} probably does not have speech please check it !!")
        is_speech = False

    # save
    torchaudio.save(out_path, wav[None, :], gt_sample_rate)
    return out_path, is_speech
|
||||
Reference in New Issue
Block a user