Add files via upload

This commit is contained in:
Sam Khoze
2024-06-18 13:18:37 -07:00
committed by GitHub
parent 9e69cc24ea
commit 6af40bc6cf
74 changed files with 8332 additions and 0 deletions
View File
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1
View File
@@ -0,0 +1 @@
from TTS.utils.audio.processor import AudioProcessor
Binary file not shown.
+485
View File
@@ -0,0 +1,485 @@
from io import BytesIO
from typing import Tuple
import librosa
import numpy as np
import scipy
import soundfile as sf
from librosa import magphase, pyin
# For using kwargs
# pylint: disable=unused-argument
def build_mel_basis(
*,
sample_rate: int = None,
fft_size: int = None,
num_mels: int = None,
mel_fmax: int = None,
mel_fmin: int = None,
**kwargs,
) -> np.ndarray:
"""Build melspectrogram basis.
Returns:
np.ndarray: melspectrogram basis.
"""
if mel_fmax is not None:
assert mel_fmax <= sample_rate // 2
assert mel_fmax - mel_fmin > 0
return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
def millisec_to_length(
*, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
) -> Tuple[int, int]:
"""Compute hop and window length from milliseconds.
Returns:
Tuple[int, int]: hop length and window length for STFT.
"""
factor = frame_length_ms / frame_shift_ms
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
win_length = int(frame_length_ms / 1000.0 * sample_rate)
hop_length = int(win_length / float(factor))
return win_length, hop_length
def _log(x, base):
if base == 10:
return np.log10(x)
return np.log(x)
def _exp(x, base):
if base == 10:
return np.power(10, x)
return np.exp(x)
def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
"""Convert amplitude values to decibels.
Args:
x (np.ndarray): Amplitude spectrogram.
gain (float): Gain factor. Defaults to 1.
base (int): Logarithm base. Defaults to 10.
Returns:
np.ndarray: Decibels spectrogram.
"""
assert (x < 0).sum() == 0, " [!] Input values must be non-negative."
return gain * _log(np.maximum(1e-8, x), base)
# pylint: disable=no-self-use
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
"""Convert decibels spectrogram to amplitude spectrogram.
Args:
x (np.ndarray): Decibels spectrogram.
gain (float): Gain factor. Defaults to 1.
base (int): Logarithm base. Defaults to 10.
Returns:
np.ndarray: Amplitude spectrogram.
"""
return _exp(x / gain, base)
def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
Args:
x (np.ndarray): Audio signal.
Raises:
RuntimeError: Preemphasis coeff is set to 0.
Returns:
np.ndarray: Decorrelated audio signal.
"""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1, -coef], [1], x)
def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
"""Reverse pre-emphasis."""
if coef == 0:
raise RuntimeError(" [!] Preemphasis is set 0.0.")
return scipy.signal.lfilter([1], [1, -coef], x)
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
Args:
spec (np.ndarray): Normalized full scale linear spectrogram.
Shapes:
- spec: :math:`[C, T]`
Returns:
np.ndarray: Normalized melspectrogram.
"""
return np.dot(mel_basis, spec)
def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to full scale spectrogram."""
assert (mel < 0).sum() == 0, " [!] Input values must be non-negative."
inv_mel_basis = np.linalg.pinv(mel_basis)
return np.maximum(1e-10, np.dot(inv_mel_basis, mel))
def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
"""Compute a spectrogram from a waveform.
Args:
wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length`
"""
D = stft(y=wav, **kwargs)
S = np.abs(D)
return S.astype(np.float32)
def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
"""Compute a melspectrogram from a waveform."""
D = stft(y=wav, **kwargs)
S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs)
return S.astype(np.float32)
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
S = spec.copy()
return griffin_lim(spec=S**power, **kwargs)
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
S = mel.copy()
S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear
return griffin_lim(spec=S**power, **kwargs)
### STFT and ISTFT ###
def stft(
*,
y: np.ndarray = None,
fft_size: int = None,
hop_length: int = None,
win_length: int = None,
pad_mode: str = "reflect",
window: str = "hann",
center: bool = True,
**kwargs,
) -> np.ndarray:
"""Librosa STFT wrapper.
Check http://librosa.org/doc/main/generated/librosa.stft.html argument details.
Returns:
np.ndarray: Complex number array.
"""
return librosa.stft(
y=y,
n_fft=fft_size,
hop_length=hop_length,
win_length=win_length,
pad_mode=pad_mode,
window=window,
center=center,
)
def istft(
*,
y: np.ndarray = None,
hop_length: int = None,
win_length: int = None,
window: str = "hann",
center: bool = True,
**kwargs,
) -> np.ndarray:
"""Librosa iSTFT wrapper.
Check http://librosa.org/doc/main/generated/librosa.istft.html argument details.
Returns:
np.ndarray: Complex number array.
"""
return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)
def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
S_complex = np.abs(spec).astype(complex)
y = istft(y=S_complex * angles, **kwargs)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(num_iter):
angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
y = istft(y=S_complex * angles, **kwargs)
return y
def compute_stft_paddings(
*, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
) -> Tuple[int, int]:
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
(first and final frames)"""
pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
if not pad_two_sides:
return 0, pad
return pad // 2, pad // 2 + pad % 2
def compute_f0(
*,
x: np.ndarray = None,
pitch_fmax: float = None,
pitch_fmin: float = None,
hop_length: int = None,
win_length: int = None,
sample_rate: int = None,
stft_pad_mode: str = "reflect",
center: bool = True,
**kwargs,
) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
Args:
x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
pitch_fmax (float): Pitch max value.
pitch_fmin (float): Pitch min value.
hop_length (int): Number of frames between STFT columns.
win_length (int): STFT window length.
sample_rate (int): Audio sampling rate.
stft_pad_mode (str): Padding mode for STFT.
center (bool): Centered padding.
Returns:
np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
Examples:
>>> WAV_FILE = filename = librosa.example('vibeace')
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio import AudioProcessor
>>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
>>> pitch = ap.compute_f0(wav)
"""
assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
assert pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
f0, voiced_mask, _ = pyin(
y=x.astype(np.double),
fmin=pitch_fmin,
fmax=pitch_fmax,
sr=sample_rate,
frame_length=win_length,
win_length=win_length // 2,
hop_length=hop_length,
pad_mode=stft_pad_mode,
center=center,
n_thresholds=100,
beta_parameters=(2, 18),
boltzmann_parameter=2,
resolution=0.1,
max_transition_rate=35.92,
switch_prob=0.01,
no_trough_prob=0.01,
)
f0[~voiced_mask] = 0.0
return f0
def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
"""Compute energy of a waveform using the same parameters used for computing melspectrogram.
Args:
x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length`
Examples:
>>> WAV_FILE = filename = librosa.example('vibeace')
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio import AudioProcessor
>>> conf = BaseAudioConfig()
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
>>> energy = ap.compute_energy(wav)
"""
x = stft(y=y, **kwargs)
mag, _ = magphase(x)
energy = np.sqrt(np.sum(mag**2, axis=0))
return energy
### Audio Processing ###
def find_endpoint(
*,
wav: np.ndarray = None,
trim_db: float = -40,
sample_rate: int = None,
min_silence_sec=0.8,
gain: float = None,
base: int = None,
**kwargs,
) -> int:
"""Find the last point without silence at the end of a audio signal.
Args:
wav (np.ndarray): Audio signal.
threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8.
gian (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None.
base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10.
Returns:
int: Last point without silence.
"""
window_length = int(sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x : x + window_length]) < threshold:
return x + hop_length
return len(wav)
def trim_silence(
*,
wav: np.ndarray = None,
sample_rate: int = None,
trim_db: float = None,
win_length: int = None,
hop_length: int = None,
**kwargs,
) -> np.ndarray:
"""Trim silent parts with a threshold and 0.01 sec margin"""
margin = int(sample_rate * 0.01)
wav = wav[margin:-margin]
return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
"""Normalize the volume of an audio signal.
Args:
x (np.ndarray): Raw waveform.
coef (float): Coefficient to rescale the maximum value. Defaults to 0.95.
Returns:
np.ndarray: Volume normalized waveform.
"""
return x / abs(x).max() * coef
def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
return wav * a
def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
"""Normalize the volume based on RMS of the signal.
Args:
x (np.ndarray): Raw waveform.
db_level (float): Target dB level in RMS. Defaults to -27.0.
Returns:
np.ndarray: RMS normalized waveform.
"""
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
wav = rms_norm(wav=x, db_level=db_level)
return wav
def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
Args:
filename (str): Path to the wav file.
sr (int, optional): Sampling rate for resampling. Defaults to None.
resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.
Returns:
np.ndarray: Loaded waveform.
"""
if resample:
# loading with resampling. It is significantly slower.
x, _ = librosa.load(filename, sr=sample_rate)
else:
# SF is faster than librosa for loading files
x, _ = sf.read(filename)
return x
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None:
"""Save float waveform to a file using Scipy.
Args:
wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
path (str): Path to a output file.
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
"""
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
wav_norm = wav_norm.astype(np.int16)
if pipe_out:
wav_buffer = BytesIO()
scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm)
wav_buffer.seek(0)
pipe_out.buffer.write(wav_buffer.read())
scipy.io.wavfile.write(path, sample_rate, wav_norm)
def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
mu = 2**mulaw_qc - 1
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
signal = (signal + 1) / 2 * mu + 0.5
return np.floor(
signal,
)
def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
"""Recovers waveform from quantized values."""
mu = 2**mulaw_qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
"""Quantize a waveform to a given number of bits.
Args:
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
quantize_bits (int): Number of quantization bits.
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2**quantize_bits - 1) / 2
def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2**quantize_bits - 1) - 1
+633
View File
@@ -0,0 +1,633 @@
from io import BytesIO
from typing import Dict, Tuple
import librosa
import numpy as np
import scipy.io.wavfile
import scipy.signal
from TTS.tts.utils.helpers import StandardScaler
from TTS.utils.audio.numpy_transforms import (
amp_to_db,
build_mel_basis,
compute_f0,
db_to_amp,
deemphasis,
find_endpoint,
griffin_lim,
load_wav,
mel_to_spec,
millisec_to_length,
preemphasis,
rms_volume_norm,
spec_to_mel,
stft,
trim_silence,
volume_norm,
)
# pylint: disable=too-many-public-methods
class AudioProcessor(object):
"""Audio Processor for TTS.
Note:
All the class arguments are set to default values to enable a flexible initialization
of the class with the model config. They are not meaningful for all the arguments.
Args:
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
resample (bool, optional):
enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
num_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
log_func (int, optional):
log exponent used for converting spectrogram aplitude to DB.
min_level_db (int, optional):
minimum db threshold for the computed melspectrograms. Defaults to None.
frame_shift_ms (int, optional):
milliseconds of frames between STFT columns. Defaults to None.
frame_length_ms (int, optional):
milliseconds of STFT window length. Defaults to None.
hop_length (int, optional):
number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
win_length (int, optional):
STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
ref_level_db (int, optional):
reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None.
fft_size (int, optional):
FFT window size for STFT. Defaults to 1024.
power (int, optional):
Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
preemphasis (float, optional):
Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
signal_norm (bool, optional):
enable/disable signal normalization. Defaults to None.
symmetric_norm (bool, optional):
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
max_norm (float, optional):
```k``` defining the normalization range. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
pitch_fmin (int, optional):
minimum filter frequency for computing pitch. Defaults to None.
pitch_fmax (int, optional):
maximum filter frequency for computing pitch. Defaults to None.
spec_gain (int, optional):
gain applied when converting amplitude to DB. Defaults to 20.
stft_pad_mode (str, optional):
Padding mode for STFT. Defaults to 'reflect'.
clip_norm (bool, optional):
enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
griffin_lim_iters (int, optional):
Number of GriffinLim iterations. Defaults to None.
do_trim_silence (bool, optional):
enable/disable silence trimming when loading the audio signal. Defaults to False.
trim_db (int, optional):
DB threshold used for silence trimming. Defaults to 60.
do_sound_norm (bool, optional):
enable/disable signal normalization. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
do_rms_norm (bool, optional):
enable/disable RMS volume normalization when loading an audio file. Defaults to False.
db_level (int, optional):
dB level used for rms normalization. The range is -99 to 0. Defaults to None.
stats_path (str, optional):
Path to the computed stats file. Defaults to None.
verbose (bool, optional):
enable/disable logging. Defaults to True.
"""
def __init__(
self,
sample_rate=None,
resample=False,
num_mels=None,
log_func="np.log10",
min_level_db=None,
frame_shift_ms=None,
frame_length_ms=None,
hop_length=None,
win_length=None,
ref_level_db=None,
fft_size=1024,
power=None,
preemphasis=0.0,
signal_norm=None,
symmetric_norm=None,
max_norm=None,
mel_fmin=None,
mel_fmax=None,
pitch_fmax=None,
pitch_fmin=None,
spec_gain=20,
stft_pad_mode="reflect",
clip_norm=True,
griffin_lim_iters=None,
do_trim_silence=False,
trim_db=60,
do_sound_norm=False,
do_amp_to_db_linear=True,
do_amp_to_db_mel=True,
do_rms_norm=False,
db_level=None,
stats_path=None,
verbose=True,
**_,
):
# setup class attributed
self.sample_rate = sample_rate
self.resample = resample
self.num_mels = num_mels
self.log_func = log_func
self.min_level_db = min_level_db or 0
self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms
self.ref_level_db = ref_level_db
self.fft_size = fft_size
self.power = power
self.preemphasis = preemphasis
self.griffin_lim_iters = griffin_lim_iters
self.signal_norm = signal_norm
self.symmetric_norm = symmetric_norm
self.mel_fmin = mel_fmin or 0
self.mel_fmax = mel_fmax
self.pitch_fmin = pitch_fmin
self.pitch_fmax = pitch_fmax
self.spec_gain = float(spec_gain)
self.stft_pad_mode = stft_pad_mode
self.max_norm = 1.0 if max_norm is None else float(max_norm)
self.clip_norm = clip_norm
self.do_trim_silence = do_trim_silence
self.trim_db = trim_db
self.do_sound_norm = do_sound_norm
self.do_amp_to_db_linear = do_amp_to_db_linear
self.do_amp_to_db_mel = do_amp_to_db_mel
self.do_rms_norm = do_rms_norm
self.db_level = db_level
self.stats_path = stats_path
# setup exp_func for db to amp conversion
if log_func == "np.log":
self.base = np.e
elif log_func == "np.log10":
self.base = 10
else:
raise ValueError(" [!] unknown `log_func` value.")
# setup stft parameters
if hop_length is None:
# compute stft parameters from given time values
self.win_length, self.hop_length = millisec_to_length(
frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
)
else:
# use stft parameters from config file
self.hop_length = hop_length
self.win_length = win_length
assert min_level_db != 0.0, " [!] min_level_db is 0"
assert (
self.win_length <= self.fft_size
), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
members = vars(self)
if verbose:
print(" > Setting up Audio Processor...")
for key, value in members.items():
print(" | > {}:{}".format(key, value))
# create spectrogram utils
self.mel_basis = build_mel_basis(
sample_rate=self.sample_rate,
fft_size=self.fft_size,
num_mels=self.num_mels,
mel_fmax=self.mel_fmax,
mel_fmin=self.mel_fmin,
)
# setup scaler
if stats_path and signal_norm:
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
self.signal_norm = True
self.max_norm = None
self.clip_norm = None
self.symmetric_norm = None
@staticmethod
def init_from_config(config: "Coqpit", verbose=True):
if "audio" in config:
return AudioProcessor(verbose=verbose, **config.audio)
return AudioProcessor(verbose=verbose, **config)
### normalization ###
def normalize(self, S: np.ndarray) -> np.ndarray:
"""Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
Args:
S (np.ndarray): Spectrogram to normalize.
Raises:
RuntimeError: Mean and variance is computed from incompatible parameters.
Returns:
np.ndarray: Normalized spectrogram.
"""
# pylint: disable=no-else-return
S = S.copy()
if self.signal_norm:
# mean-var scaling
if hasattr(self, "mel_scaler"):
if S.shape[0] == self.num_mels:
return self.mel_scaler.transform(S.T).T
elif S.shape[0] == self.fft_size / 2:
return self.linear_scaler.transform(S.T).T
else:
raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
# range normalization
S -= self.ref_level_db # discard certain range of DB assuming it is air noise
S_norm = (S - self.min_level_db) / (-self.min_level_db)
if self.symmetric_norm:
S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
if self.clip_norm:
S_norm = np.clip(
S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
)
return S_norm
else:
S_norm = self.max_norm * S_norm
if self.clip_norm:
S_norm = np.clip(S_norm, 0, self.max_norm)
return S_norm
else:
return S
def denormalize(self, S: np.ndarray) -> np.ndarray:
"""Denormalize spectrogram values.
Args:
S (np.ndarray): Spectrogram to denormalize.
Raises:
RuntimeError: Mean and variance are incompatible.
Returns:
np.ndarray: Denormalized spectrogram.
"""
# pylint: disable=no-else-return
S_denorm = S.copy()
if self.signal_norm:
# mean-var scaling
if hasattr(self, "mel_scaler"):
if S_denorm.shape[0] == self.num_mels:
return self.mel_scaler.inverse_transform(S_denorm.T).T
elif S_denorm.shape[0] == self.fft_size / 2:
return self.linear_scaler.inverse_transform(S_denorm.T).T
else:
raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
if self.symmetric_norm:
if self.clip_norm:
S_denorm = np.clip(
S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
)
S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
return S_denorm + self.ref_level_db
else:
if self.clip_norm:
S_denorm = np.clip(S_denorm, 0, self.max_norm)
S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
return S_denorm + self.ref_level_db
else:
return S_denorm
### Mean-STD scaling ###
def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]:
"""Loading mean and variance statistics from a `npy` file.
Args:
stats_path (str): Path to the `npy` file containing
Returns:
Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to
compute them.
"""
stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg
mel_mean = stats["mel_mean"]
mel_std = stats["mel_std"]
linear_mean = stats["linear_mean"]
linear_std = stats["linear_std"]
stats_config = stats["audio_config"]
# check all audio parameters used for computing stats
skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"]
for key in stats_config.keys():
if key in skip_parameters:
continue
if key not in ["sample_rate", "trim_db"]:
assert (
stats_config[key] == self.__dict__[key]
), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
return mel_mean, mel_std, linear_mean, linear_std, stats_config
# pylint: disable=attribute-defined-outside-init
def setup_scaler(
self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray
) -> None:
"""Initialize scaler objects used in mean-std normalization.
Args:
mel_mean (np.ndarray): Mean for melspectrograms.
mel_std (np.ndarray): STD for melspectrograms.
linear_mean (np.ndarray): Mean for full scale spectrograms.
linear_std (np.ndarray): STD for full scale spectrograms.
"""
self.mel_scaler = StandardScaler()
self.mel_scaler.set_stats(mel_mean, mel_std)
self.linear_scaler = StandardScaler()
self.linear_scaler.set_stats(linear_mean, linear_std)
### Preemphasis ###
def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
Args:
x (np.ndarray): Audio signal.
Raises:
RuntimeError: Preemphasis coeff is set to 0.
Returns:
np.ndarray: Decorrelated audio signal.
"""
return preemphasis(x=x, coef=self.preemphasis)
def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
"""Reverse pre-emphasis."""
return deemphasis(x=x, coef=self.preemphasis)
### SPECTROGRAMs ###
def spectrogram(self, y: np.ndarray) -> np.ndarray:
"""Compute a spectrogram from a waveform.
Args:
y (np.ndarray): Waveform.
Returns:
np.ndarray: Spectrogram.
"""
if self.preemphasis != 0:
y = self.apply_preemphasis(y)
D = stft(
y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
if self.do_amp_to_db_linear:
S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
else:
S = np.abs(D)
return self.normalize(S).astype(np.float32)
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
"""Compute a melspectrogram from a waveform."""
if self.preemphasis != 0:
y = self.apply_preemphasis(y)
D = stft(
y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis)
if self.do_amp_to_db_mel:
S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
return self.normalize(S).astype(np.float32)
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
S = self.denormalize(spectrogram)
S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
# Reconstruct phase
W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
D = self.denormalize(mel_spectrogram)
S = db_to_amp(x=D, gain=self.spec_gain, base=self.base)
S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear
W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
Args:
linear_spec (np.ndarray): Normalized full scale linear spectrogram.
Returns:
np.ndarray: Normalized melspectrogram.
"""
S = self.denormalize(linear_spec)
S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
mel = self.normalize(S)
return mel
def _griffin_lim(self, S):
return griffin_lim(
spec=S,
num_iter=self.griffin_lim_iters,
hop_length=self.hop_length,
win_length=self.win_length,
fft_size=self.fft_size,
pad_mode=self.stft_pad_mode,
)
def compute_f0(self, x: np.ndarray) -> np.ndarray:
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
Args:
x (np.ndarray): Waveform.
Returns:
np.ndarray: Pitch.
Examples:
>>> WAV_FILE = filename = librosa.example('vibeace')
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio import AudioProcessor
>>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
>>> pitch = ap.compute_f0(wav)
"""
# align F0 length to the spectrogram length
if len(x) % self.hop_length == 0:
x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
f0 = compute_f0(
x=x,
pitch_fmax=self.pitch_fmax,
pitch_fmin=self.pitch_fmin,
hop_length=self.hop_length,
win_length=self.win_length,
sample_rate=self.sample_rate,
stft_pad_mode=self.stft_pad_mode,
center=True,
)
return f0
### Audio Processing ###
def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int:
"""Find the last point without silence at the end of a audio signal.
Args:
wav (np.ndarray): Audio signal.
threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8.
Returns:
int: Last point without silence.
"""
return find_endpoint(
wav=wav,
trim_db=self.trim_db,
sample_rate=self.sample_rate,
min_silence_sec=min_silence_sec,
gain=self.spec_gain,
base=self.base,
)
def trim_silence(self, wav):
"""Trim silent parts with a threshold and 0.01 sec margin"""
return trim_silence(
wav=wav,
sample_rate=self.sample_rate,
trim_db=self.trim_db,
win_length=self.win_length,
hop_length=self.hop_length,
)
@staticmethod
def sound_norm(x: np.ndarray) -> np.ndarray:
"""Normalize the volume of an audio signal.
Args:
x (np.ndarray): Raw waveform.
Returns:
np.ndarray: Volume normalized waveform.
"""
return volume_norm(x=x)
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
"""Normalize the volume based on RMS of the signal.
Args:
x (np.ndarray): Raw waveform.
Returns:
np.ndarray: RMS normalized waveform.
"""
if db_level is None:
db_level = self.db_level
return rms_volume_norm(x=x, db_level=db_level)
### save and load ###
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
Args:
filename (str): Path to the wav file.
sr (int, optional): Sampling rate for resampling. Defaults to None.
Returns:
np.ndarray: Loaded waveform.
"""
if sr is not None:
x = load_wav(filename=filename, sample_rate=sr, resample=True)
else:
x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
if self.do_trim_silence:
try:
x = self.trim_silence(x)
except ValueError:
print(f" [!] File cannot be trimmed for silence - {filename}")
if self.do_sound_norm:
x = self.sound_norm(x)
if self.do_rms_norm:
x = self.rms_volume_norm(x, self.db_level)
return x
def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
"""Save a waveform to a file using Scipy.
Args:
wav (np.ndarray): Waveform to save.
path (str): Path to a output file.
sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
"""
if self.do_rms_norm:
wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
else:
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
wav_norm = wav_norm.astype(np.int16)
if pipe_out:
wav_buffer = BytesIO()
scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm)
wav_buffer.seek(0)
pipe_out.buffer.write(wav_buffer.read())
scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm)
def get_duration(self, filename: str) -> float:
"""Get the duration of a wav file using Librosa.
Args:
filename (str): Path to the wav file.
"""
return librosa.get_duration(filename=filename)
+165
View File
@@ -0,0 +1,165 @@
import librosa
import torch
from torch import nn
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
Args:
n_fft (int):
FFT window size for STFT.
hop_length (int):
number of frames between STFT columns.
win_length (int, optional):
STFT window length.
pad_wav (bool, optional):
If True pad the audio with (n_fft - hop_length) / 2). Defaults to False.
window (str, optional):
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
sample_rate (int, optional):
target audio sampling rate. Defaults to None.
mel_fmin (int, optional):
minimum filter frequency for computing melspectrograms. Defaults to None.
mel_fmax (int, optional):
maximum filter frequency for computing melspectrograms. Defaults to None.
n_mels (int, optional):
number of melspectrogram dimensions. Defaults to None.
use_mel (bool, optional):
If True compute the melspectrograms otherwise. Defaults to False.
do_amp_to_db_linear (bool, optional):
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
spec_gain (float, optional):
gain applied when converting amplitude to DB. Defaults to 1.0.
power (float, optional):
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
use_htk (bool, optional):
Use HTK formula in mel filter instead of Slaney.
mel_norm (None, 'slaney', or number, optional):
If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization).
If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm.
See `librosa.util.normalize` for a full description of supported norm values
(including `+-np.inf`).
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
"""
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
do_amp_to_db=False,
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney",
normalized=False,
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.pad_wav = pad_wav
self.sample_rate = sample_rate
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.do_amp_to_db = do_amp_to_db
self.spec_gain = spec_gain
self.power = power
self.use_htk = use_htk
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
self.normalized = normalized
if use_mel:
self._build_mel_basis()
def __call__(self, x):
"""Compute spectrogram frames by torch based stft.
Args:
x (Tensor): input waveform
Returns:
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [:math:`[B, 1, T]`]
"""
if x.ndim == 2:
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=self.normalized,
onesided=True,
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
if self.do_amp_to_db:
S = self._amp_to_db(S, spec_gain=self.spec_gain)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
sr=self.sample_rate,
n_fft=self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()
@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain
+105
View File
@@ -0,0 +1,105 @@
class TrainerCallback:
@staticmethod
def on_init_start(trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_init_start"):
trainer.model.module.on_init_start(trainer)
else:
if hasattr(trainer.model, "on_init_start"):
trainer.model.on_init_start(trainer)
if hasattr(trainer.criterion, "on_init_start"):
trainer.criterion.on_init_start(trainer)
if hasattr(trainer.optimizer, "on_init_start"):
trainer.optimizer.on_init_start(trainer)
@staticmethod
def on_init_end(trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_init_end"):
trainer.model.module.on_init_end(trainer)
else:
if hasattr(trainer.model, "on_init_end"):
trainer.model.on_init_end(trainer)
if hasattr(trainer.criterion, "on_init_end"):
trainer.criterion.on_init_end(trainer)
if hasattr(trainer.optimizer, "on_init_end"):
trainer.optimizer.on_init_end(trainer)
@staticmethod
def on_epoch_start(trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_epoch_start"):
trainer.model.module.on_epoch_start(trainer)
else:
if hasattr(trainer.model, "on_epoch_start"):
trainer.model.on_epoch_start(trainer)
if hasattr(trainer.criterion, "on_epoch_start"):
trainer.criterion.on_epoch_start(trainer)
if hasattr(trainer.optimizer, "on_epoch_start"):
trainer.optimizer.on_epoch_start(trainer)
@staticmethod
def on_epoch_end(trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_epoch_end"):
trainer.model.module.on_epoch_end(trainer)
else:
if hasattr(trainer.model, "on_epoch_end"):
trainer.model.on_epoch_end(trainer)
if hasattr(trainer.criterion, "on_epoch_end"):
trainer.criterion.on_epoch_end(trainer)
if hasattr(trainer.optimizer, "on_epoch_end"):
trainer.optimizer.on_epoch_end(trainer)
@staticmethod
def on_train_step_start(trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_step_start"):
trainer.model.module.on_train_step_start(trainer)
else:
if hasattr(trainer.model, "on_train_step_start"):
trainer.model.on_train_step_start(trainer)
if hasattr(trainer.criterion, "on_train_step_start"):
trainer.criterion.on_train_step_start(trainer)
if hasattr(trainer.optimizer, "on_train_step_start"):
trainer.optimizer.on_train_step_start(trainer)
@staticmethod
def on_train_step_end(trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_train_step_end"):
trainer.model.module.on_train_step_end(trainer)
else:
if hasattr(trainer.model, "on_train_step_end"):
trainer.model.on_train_step_end(trainer)
if hasattr(trainer.criterion, "on_train_step_end"):
trainer.criterion.on_train_step_end(trainer)
if hasattr(trainer.optimizer, "on_train_step_end"):
trainer.optimizer.on_train_step_end(trainer)
@staticmethod
def on_keyboard_interrupt(trainer) -> None:
if hasattr(trainer.model, "module"):
if hasattr(trainer.model.module, "on_keyboard_interrupt"):
trainer.model.module.on_keyboard_interrupt(trainer)
else:
if hasattr(trainer.model, "on_keyboard_interrupt"):
trainer.model.on_keyboard_interrupt(trainer)
if hasattr(trainer.criterion, "on_keyboard_interrupt"):
trainer.criterion.on_keyboard_interrupt(trainer)
if hasattr(trainer.optimizer, "on_keyboard_interrupt"):
trainer.optimizer.on_keyboard_interrupt(trainer)
+67
View File
@@ -0,0 +1,67 @@
from typing import Generator
from trainer.trainer_utils import get_optimizer
class CapacitronOptimizer:
"""Double optimizer class for the Capacitron model."""
def __init__(self, config: dict, model_params: Generator) -> None:
self.primary_params, self.secondary_params = self.split_model_parameters(model_params)
optimizer_names = list(config.optimizer_params.keys())
optimizer_parameters = list(config.optimizer_params.values())
self.primary_optimizer = get_optimizer(
optimizer_names[0],
optimizer_parameters[0],
config.lr,
parameters=self.primary_params,
)
self.secondary_optimizer = get_optimizer(
optimizer_names[1],
self.extract_optimizer_parameters(optimizer_parameters[1]),
optimizer_parameters[1]["lr"],
parameters=self.secondary_params,
)
self.param_groups = self.primary_optimizer.param_groups
def first_step(self):
self.secondary_optimizer.step()
self.secondary_optimizer.zero_grad()
self.primary_optimizer.zero_grad()
def step(self):
# Update param groups to display the correct learning rate
self.param_groups = self.primary_optimizer.param_groups
self.primary_optimizer.step()
def zero_grad(self, set_to_none=False):
self.primary_optimizer.zero_grad(set_to_none)
self.secondary_optimizer.zero_grad(set_to_none)
def load_state_dict(self, state_dict):
self.primary_optimizer.load_state_dict(state_dict[0])
self.secondary_optimizer.load_state_dict(state_dict[1])
def state_dict(self):
return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()]
@staticmethod
def split_model_parameters(model_params: Generator) -> list:
primary_params = []
secondary_params = []
for name, param in model_params:
if param.requires_grad:
if name == "capacitron_vae_layer.beta":
secondary_params.append(param)
else:
primary_params.append(param)
return [iter(primary_params), iter(secondary_params)]
@staticmethod
def extract_optimizer_parameters(params: dict) -> dict:
"""Extract parameters that are not the learning rate"""
return {k: v for k, v in params.items() if k != "lr"}
+20
View File
@@ -0,0 +1,20 @@
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import torch
import torch.distributed as dist
def reduce_tensor(tensor, num_gpus):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.reduce_op.SUM)
rt /= num_gpus
return rt
def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
assert torch.cuda.is_available(), "Distributed mode requires CUDA."
# Set cuda device so everything is done on the right GPU.
torch.cuda.set_device(rank % torch.cuda.device_count())
# Initialize distributed communication
dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
+206
View File
@@ -0,0 +1,206 @@
# Adapted from https://github.com/pytorch/audio/
import hashlib
import logging
import os
import tarfile
import urllib
import urllib.request
import zipfile
from os.path import expanduser
from typing import Any, Iterable, List, Optional
from torch.utils.model_zoo import tqdm
def stream_url(
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
"""Stream url by chunk
Args:
url (str): Url.
start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
"""
# If we already have the whole file, there is no need to download it again
req = urllib.request.Request(url, method="HEAD")
with urllib.request.urlopen(req) as response:
url_size = int(response.info().get("Content-Length", -1))
if url_size == start_byte:
return
req = urllib.request.Request(url)
if start_byte:
req.headers["Range"] = "bytes={}-".format(start_byte)
with urllib.request.urlopen(req) as upointer, tqdm(
unit="B",
unit_scale=True,
unit_divisor=1024,
total=url_size,
disable=not progress_bar,
) as pbar:
num_bytes = 0
while True:
chunk = upointer.read(block_size)
if not chunk:
break
yield chunk
num_bytes += len(chunk)
pbar.update(len(chunk))
def download_url(
url: str,
download_folder: str,
filename: Optional[str] = None,
hash_value: Optional[str] = None,
hash_type: str = "sha256",
progress_bar: bool = True,
resume: bool = False,
) -> None:
"""Download file to disk.
Args:
url (str): Url.
download_folder (str): Folder to download file.
filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
(Default: ``None``).
hash_value (str or None, optional): Hash for url (Default: ``None``).
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
resume (bool, optional): Enable resuming download (Default: ``False``).
"""
req = urllib.request.Request(url, method="HEAD")
req_info = urllib.request.urlopen(req).info() # pylint: disable=consider-using-with
# Detect filename
filename = filename or req_info.get_filename() or os.path.basename(url)
filepath = os.path.join(download_folder, filename)
if resume and os.path.exists(filepath):
mode = "ab"
local_size: Optional[int] = os.path.getsize(filepath)
elif not resume and os.path.exists(filepath):
raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
else:
mode = "wb"
local_size = None
if hash_value and local_size == int(req_info.get("Content-Length", -1)):
with open(filepath, "rb") as file_obj:
if validate_file(file_obj, hash_value, hash_type):
return
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
with open(filepath, mode) as fpointer:
for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
fpointer.write(chunk)
with open(filepath, "rb") as file_obj:
if hash_value and not validate_file(file_obj, hash_value, hash_type):
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
"""Validate a given file object with its hash.
Args:
file_obj: File object to read from.
hash_value (str): Hash for url.
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
Returns:
bool: return True if its a valid file, else False.
"""
if hash_type == "sha256":
hash_func = hashlib.sha256()
elif hash_type == "md5":
hash_func = hashlib.md5()
else:
raise ValueError
while True:
# Read by chunk to avoid filling memory
chunk = file_obj.read(1024**2)
if not chunk:
break
hash_func.update(chunk)
return hash_func.hexdigest() == hash_value
def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
"""Extract archive.
Args:
from_path (str): the path of the archive.
to_path (str or None, optional): the root path of the extraced files (directory of from_path)
(Default: ``None``)
overwrite (bool, optional): overwrite existing files (Default: ``False``)
Returns:
list: List of paths to extracted files even if not overwritten.
"""
if to_path is None:
to_path = os.path.dirname(from_path)
try:
with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file %s.", from_path)
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
return files
except tarfile.ReadError:
pass
try:
with zipfile.ZipFile(from_path, "r") as zfile:
logging.info("Opened zip file %s.", from_path)
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
return files
except zipfile.BadZipFile:
pass
raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
"""Download dataset from kaggle.
Args:
dataset_path (str):
This the kaggle link to the dataset. for example vctk is 'mfekadu/english-multispeaker-corpus-for-voice-cloning'
dataset_name (str): Name of the folder the dataset will be saved in.
output_path (str): Path of the location you want the dataset folder to be saved to.
"""
data_path = os.path.join(output_path, dataset_name)
try:
import kaggle # pylint: disable=import-outside-toplevel
kaggle.api.authenticate()
print(f"""\nDownloading {dataset_name}...""")
kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
except OSError:
print(
f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
)
+126
View File
@@ -0,0 +1,126 @@
import os
from typing import Optional
from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
def download_ljspeech(path: str):
"""Download and extract LJSpeech dataset
Args:
path (str): path to the directory where the dataset will be stored.
"""
os.makedirs(path, exist_ok=True)
url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
def download_vctk(path: str, use_kaggle: Optional[bool] = False):
"""Download and extract VCTK dataset.
Args:
path (str): path to the directory where the dataset will be stored.
use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False.
"""
if use_kaggle:
download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path)
else:
os.makedirs(path, exist_ok=True)
url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
def download_tweb(path: str):
"""Download and extract Tweb dataset
Args:
path (str): Path to the directory where the dataset will be stored.
"""
download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path)
def download_libri_tts(path: str, subset: Optional[str] = "all"):
"""Download and extract libri tts dataset.
Args:
path (str): Path to the directory where the dataset will be stored.
subset (str, optional): Name of the subset to download. If you only want to download a certain
portion specify it here. Defaults to 'all'.
"""
subset_dict = {
"libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz",
"libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz",
"libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz",
"libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz",
"libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz",
"libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz",
"libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz",
}
os.makedirs(path, exist_ok=True)
if subset == "all":
for sub, val in subset_dict.items():
print(f" > Downloading {sub}...")
download_url(val, path)
basename = os.path.basename(val)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
print(" > All subsets downloaded")
else:
url = subset_dict[subset]
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
def download_thorsten_de(path: str):
"""Download and extract Thorsten german male voice dataset.
Args:
path (str): Path to the directory where the dataset will be stored.
"""
os.makedirs(path, exist_ok=True)
url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz"
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
def download_mailabs(path: str, language: str = "english"):
"""Download and extract Mailabs dataset.
Args:
path (str): Path to the directory where the dataset will be stored.
language (str): Language subset to download. Defaults to english.
"""
language_dict = {
"english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz",
"german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz",
"french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz",
"italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz",
"spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz",
}
os.makedirs(path, exist_ok=True)
url = language_dict[language]
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
+239
View File
@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
import datetime
import importlib
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict
import fsspec
import torch
def to_cuda(x: torch.Tensor) -> torch.Tensor:
if x is None:
return None
if torch.is_tensor(x):
x = x.contiguous()
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return x
def get_cuda():
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return use_cuda, device
def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n") if line.startswith("*"))
current.replace("* ", "")
except subprocess.CalledProcessError:
current = "inside_docker"
except (FileNotFoundError, StopIteration) as e:
current = "unknown"
return current
def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
# try:
# subprocess.check_output(['git', 'diff-index', '--quiet',
# 'HEAD']) # Verify client is clean
# except:
# raise RuntimeError(
# " !! Commit before training to get the commit hash.")
try:
commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
# Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000"
return commit
def get_experiment_folder_path(root_path, model_name):
"""Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
return output_folder
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
fs = fsspec.get_mapper(experiment_path).fs
checkpoint_files = fs.glob(experiment_path + "/*.pth")
if not checkpoint_files:
if fs.exists(experiment_path):
fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def to_camel(text):
text = text.capitalize()
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
text = text.replace("Tts", "TTS")
text = text.replace("vc", "VC")
return text
def find_module(module_path: str, module_name: str) -> object:
module_name = module_name.lower()
module = importlib.import_module(module_path + "." + module_name)
class_name = to_camel(module_name)
return getattr(module, class_name)
def import_class(module_path: str) -> object:
"""Import a class from a module path.
Args:
module_path (str): The module path of the class.
Returns:
object: The imported class.
"""
class_name = module_path.split(".")[-1]
module_path = ".".join(module_path.split(".")[:-1])
module = importlib.import_module(module_path)
return getattr(module, class_name)
def get_import_path(obj: object) -> str:
"""Get the import path of a class.
Args:
obj (object): The class object.
Returns:
str: The import path of the class.
"""
return ".".join([type(obj).__module__, type(obj).__name__])
def get_user_data_dir(appname):
TTS_HOME = os.environ.get("TTS_HOME")
XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
if TTS_HOME is not None:
ans = Path(TTS_HOME).expanduser().resolve(strict=False)
elif XDG_DATA_HOME is not None:
ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
elif sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == "darwin":
ans = Path("~/Library/Application Support/").expanduser()
else:
ans = Path.home().joinpath(".local/share")
return ans.joinpath(appname)
def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
# 3. skip reinit layers
if c.has("reinit_layers") and c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
return model_dict
def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
"""Format kwargs to hande auxilary inputs to models.
Args:
def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`.
kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model.
Returns:
Dict: arguments with formatted auxilary inputs.
"""
kwargs = kwargs.copy()
for name in def_args:
if name not in kwargs or kwargs[name] is None:
kwargs[name] = def_args[name]
return kwargs
class KeepAverage:
def __init__(self):
self.avg_values = {}
self.iters = {}
def __getitem__(self, key):
return self.avg_values[key]
def items(self):
return self.avg_values.items()
def add_value(self, name, init_val=0, init_iter=0):
self.avg_values[name] = init_val
self.iters[name] = init_iter
def update_value(self, name, value, weighted_avg=False):
if name not in self.avg_values:
# add value if not exist before
self.add_value(name, init_val=value)
else:
# else update existing value
if weighted_avg:
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
self.iters[name] += 1
else:
self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
self.iters[name] += 1
self.avg_values[name] /= self.iters[name]
def add_values(self, name_dict):
for key, value in name_dict.items():
self.add_value(key, init_val=value)
def update_values(self, value_dict):
for key, value in value_dict.items():
self.update_value(key, value)
def get_timestamp():
return datetime.now().strftime("%y%m%d-%H%M%S")
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
lg = logging.getLogger(logger_name)
formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
lg.setLevel(level)
if tofile:
log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
fh = logging.FileHandler(log_file, mode="w")
fh.setFormatter(formatter)
lg.addHandler(fh)
if screen:
sh = logging.StreamHandler()
sh.setFormatter(formatter)
lg.addHandler(sh)
+70
View File
@@ -0,0 +1,70 @@
import os
import pickle as pickle_tts
from typing import Any, Callable, Dict, Union
import fsspec
import torch
from TTS.utils.generic_utils import get_user_data_dir
class RenamingUnpickler(pickle_tts.Unpickler):
"""Overload default pickler to solve module renaming problem"""
def find_class(self, module, name):
return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
class AttrDict(dict):
"""A custom dict which converts dict keys
to class attributes"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
def load_fsspec(
path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
cache: bool = True,
**kwargs,
) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
Args:
path: Any path or url supported by fsspec.
map_location: torch.device or str.
cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
**kwargs: Keyword arguments forwarded to torch.load.
Returns:
Object stored in path.
"""
is_local = os.path.isdir(path) or os.path.isfile(path)
if cache and not is_local:
with fsspec.open(
f"filecache::{path}",
filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
mode="rb",
) as f:
return torch.load(f, map_location=map_location, **kwargs)
else:
with fsspec.open(path, "rb") as f:
return torch.load(f, map_location=map_location, **kwargs)
def load_checkpoint(
model, checkpoint_path, use_cuda=False, eval=False, cache=False
): # pylint: disable=redefined-builtin
try:
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
model.load_state_dict(state["model"])
if use_cuda:
model.cuda()
if eval:
model.eval()
return model, state
+621
View File
@@ -0,0 +1,621 @@
import json
import os
import re
import tarfile
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Dict, List, Tuple
import fsspec
import requests
from tqdm import tqdm
from TTS.config import load_config, read_json_with_comments
from TTS.utils.generic_utils import get_user_data_dir
LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
"mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
"mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/",
"mit": "https://choosealicense.com/licenses/mit/",
"apache 2.0": "https://choosealicense.com/licenses/apache-2.0/",
"apache2": "https://choosealicense.com/licenses/apache-2.0/",
"cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
"cpml": "https://coqui.ai/cpml.txt",
}
class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json.
It provides an interface to list and download
models defines in '.model.json'
Models are downloaded under '.TTS' folder in the user's
home path.
Args:
models_file (str): path to .model.json file. Defaults to None.
output_prefix (str): prefix to `tts` to download models. Defaults to None
progress_bar (bool): print a progress bar when donwloading a file. Defaults to False.
verbose (bool): print info. Defaults to True.
"""
def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
super().__init__()
self.progress_bar = progress_bar
self.verbose = verbose
if output_prefix is None:
self.output_prefix = get_user_data_dir("tts")
else:
self.output_prefix = os.path.join(output_prefix, "tts")
self.models_dict = None
if models_file is not None:
self.read_models_file(models_file)
else:
# try the default location
path = Path(__file__).parent / "../.models.json"
self.read_models_file(path)
def read_models_file(self, file_path):
"""Read .models.json as a dict
Args:
file_path (str): path to .models.json.
"""
self.models_dict = read_json_with_comments(file_path)
def _list_models(self, model_type, model_count=0):
if self.verbose:
print("\n Name format: type/language/dataset/model")
model_list = []
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name)
if self.verbose:
if os.path.exists(output_path):
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return model_list
def _list_for_model_type(self, model_type):
models_name_list = []
model_count = 1
models_name_list.extend(self._list_models(model_type, model_count))
return models_name_list
def list_models(self):
models_name_list = []
model_count = 1
for model_type in self.models_dict:
model_list = self._list_models(model_type, model_count)
models_name_list.extend(model_list)
return models_name_list
def model_info_by_idx(self, model_query):
"""Print the description of the model from .models.json file using model_idx
Args:
model_query (str): <model_tye>/<model_idx>
"""
model_name_list = []
model_type, model_query_idx = model_query.split("/")
try:
model_query_idx = int(model_query_idx)
if model_query_idx <= 0:
print("> model_query_idx should be a positive integer!")
return
except:
print("> model_query_idx should be an integer!")
return
model_count = 0
if model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
else:
print(f"> model_type {model_type} does not exist in the list.")
return
if model_query_idx > model_count:
print(f"model query idx exceeds the number of available models [{model_count}] ")
else:
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")
def model_info_by_full_name(self, model_query_name):
"""Print the description of the model from .models.json file using model_full_name
Args:
model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
"""
model_type, lang, dataset, model = model_query_name.split("/")
if model_type in self.models_dict:
if lang in self.models_dict[model_type]:
if dataset in self.models_dict[model_type][lang]:
if model in self.models_dict[model_type][lang][dataset]:
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
)
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
)
else:
print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
else:
print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
else:
print(f"> lang {lang} does not exist for {model_type}.")
else:
print(f"> model_type {model_type} does not exist in the list.")
def list_tts_models(self):
"""Print all `TTS` models and return a list of model names
Format is `language/dataset/model`
"""
return self._list_for_model_type("tts_models")
def list_vocoder_models(self):
"""Print all the `vocoder` models and return a list of model names
Format is `language/dataset/model`
"""
return self._list_for_model_type("vocoder_models")
def list_vc_models(self):
"""Print all the voice conversion models and return a list of model names
Format is `language/dataset/model`
"""
return self._list_for_model_type("voice_conversion_models")
def list_langs(self):
"""Print all the available languages"""
print(" Name format: type/language")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
print(f" >: {model_type}/{lang} ")
def list_datasets(self):
"""Print all the datasets"""
print(" Name format: type/language/dataset")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}")
@staticmethod
def print_model_license(model_item: Dict):
"""Print the license of a model
Args:
model_item (dict): model item in the models.json
"""
if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}")
if model_item["license"].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
else:
print(" > Check https://opensource.org/licenses for more info.")
else:
print(" > Model's license - No license information available")
def _download_github_model(self, model_item: Dict, output_path: str):
if isinstance(model_item["github_rls_url"], list):
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
else:
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
def _download_hf_model(self, model_item: Dict, output_path: str):
if isinstance(model_item["hf_url"], list):
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
else:
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
def download_fairseq_model(self, model_name, output_path):
URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
_, lang, _, _ = model_name.split("/")
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
self._download_tar_file(model_download_uri, output_path, self.progress_bar)
@staticmethod
def set_model_url(model_item: Dict):
model_item["model_url"] = None
if "github_rls_url" in model_item:
model_item["model_url"] = model_item["github_rls_url"]
elif "hf_url" in model_item:
model_item["model_url"] = model_item["hf_url"]
elif "fairseq" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
elif "xtts" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
return model_item
def _set_model_item(self, model_name):
# fetch model info from the dict
if "fairseq" in model_name:
model_type = "tts_models"
lang = model_name.split("/")[1]
model_item = {
"model_type": "tts_models",
"license": "CC BY-NC 4.0",
"default_vocoder": None,
"author": "fairseq",
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
}
model_item["model_name"] = model_name
elif "xtts" in model_name and len(model_name.split("/")) != 4:
# loading xtts models with only model name (e.g. xtts_v2.0.2)
# check model name has the version number with regex
version_regex = r"v\d+\.\d+\.\d+"
if re.search(version_regex, model_name):
model_version = model_name.split("_")[-1]
else:
model_version = "main"
model_type = "tts_models"
lang = "multilingual"
dataset = "multi-dataset"
model = model_name
model_item = {
"default_vocoder": None,
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": True,
"hf_url": [
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth",
],
}
else:
# get model from models.json
model_type, lang, dataset, model = model_name.split("/")
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
model_item = self.set_model_url(model_item)
return model_item, model_full_name, model, md5hash
@staticmethod
def ask_tos(model_full_path):
"""Ask the user to agree to the terms of service"""
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
print(" > You must confirm the following:")
print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"')
print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
answer = input(" | | > ")
if answer.lower() == "y":
with open(tos_path, "w", encoding="utf-8") as f:
f.write("I have read, understood and agreed to the Terms and Conditions.")
return True
return False
@staticmethod
def tos_agreed(model_item, model_full_path):
"""Check if the user has agreed to the terms of service"""
if "tos_required" in model_item and model_item["tos_required"]:
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
return True
return False
return True
def create_dir_and_download_model(self, model_name, model_item, output_path):
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
def check_if_configs_are_equal(self, model_name, model_item, output_path):
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f)
remote_url = None
for url in model_item["hf_url"]:
if "config.json" in url:
remote_url = url
break
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
config_remote = json.load(f)
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files:
- *.pth : pytorch model checkpoint file.
- config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing.
Args:
model_name (str): model name as explained above.
"""
model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
if md5sum is not None:
md5sum_file = os.path.join(output_path, "hash.md5")
if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
else:
print(f" > {model_name} has been updated, clearing model cache...")
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts" in model_name:
try:
self.check_if_configs_are_equal(model_name, model_item, output_path)
except:
pass
else:
print(f" > {model_name} is already downloaded.")
else:
self.create_dir_and_download_model(model_name, model_item, output_path)
# find downloaded files
output_model_path = output_path
output_config_path = None
if (
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)
return output_model_path, output_config_path, model_item
@staticmethod
def _find_files(output_path: str) -> Tuple[str, str]:
"""Find the model and config files in the output path
Args:
output_path (str): path to the model files
Returns:
Tuple[str, str]: path to the model file and config file
"""
model_file = None
config_file = None
for file_name in os.listdir(output_path):
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
model_file = os.path.join(output_path, file_name)
elif file_name == "config.json":
config_file = os.path.join(output_path, file_name)
if model_file is None:
raise ValueError(" [!] Model file not found in the output path")
if config_file is None:
raise ValueError(" [!] Config file not found in the output path")
return model_file, config_file
@staticmethod
def _find_speaker_encoder(output_path: str) -> str:
"""Find the speaker encoder file in the output path
Args:
output_path (str): path to the model files
Returns:
str: path to the speaker encoder file
"""
speaker_encoder_file = None
for file_name in os.listdir(output_path):
if file_name in ["model_se.pth", "model_se.pth.tar"]:
speaker_encoder_file = os.path.join(output_path, file_name)
return speaker_encoder_file
def _update_paths(self, output_path: str, config_path: str) -> None:
"""Update paths for certain files in config.json after download.
Args:
output_path (str): local path the model is downloaded to.
config_path (str): local config.json path.
"""
output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth")
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
speaker_encoder_model_path = self._find_speaker_encoder(output_path)
# update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path)
# update the speakers.json file path in the model config.json to the current path
self._update_path("d_vector_file", output_d_vector_file_path, config_path)
self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path)
self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path)
# update the speaker_ids.json file path in the model config.json to the current path
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path)
# update the speaker_encoder file path in the model config.json to the current path
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
@staticmethod
def _update_path(field_name, new_path, config_path):
"""Update the path in the model config.json for the current environment after download"""
if new_path and os.path.exists(new_path):
config = load_config(config_path)
field_names = field_name.split(".")
if len(field_names) > 1:
# field name points to a sub-level field
sub_conf = config
for fd in field_names[:-1]:
if fd in sub_conf:
sub_conf = sub_conf[fd]
else:
return
if isinstance(sub_conf[field_names[-1]], list):
sub_conf[field_names[-1]] = [new_path]
else:
sub_conf[field_names[-1]] = new_path
else:
# field name points to a top-level field
if not field_name in config:
return
if isinstance(config[field_name], list):
config[field_name] = [new_path]
else:
config[field_name] = new_path
config.save_json(config_path)
@staticmethod
def _download_zip_file(file_url, output_folder, progress_bar):
"""Download the github releases"""
# download the file
r = requests.get(file_url, stream=True)
# extract the file
try:
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder)
os.remove(temp_zip_name) # delete zip after extract
except zipfile.BadZipFile:
print(f" > Error: Bad zip file - {file_url}")
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in z.namelist():
src_path = os.path.join(output_folder, file_path)
if os.path.isfile(src_path):
dst_path = os.path.join(output_folder, os.path.basename(file_path))
if src_path != dst_path:
copyfile(src_path, dst_path)
# remove redundant (hidden or not) folders
for file_path in z.namelist():
if os.path.isdir(os.path.join(output_folder, file_path)):
rmtree(os.path.join(output_folder, file_path))
@staticmethod
def _download_tar_file(file_url, output_folder, progress_bar):
"""Download the github releases"""
# download the file
r = requests.get(file_url, stream=True)
# extract the file
try:
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder)
tar_names = t.getnames()
os.remove(temp_tar_name) # delete tar after extract
except tarfile.ReadError:
print(f" > Error: Bad tar file - {file_url}")
raise tarfile.ReadError # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
src_path = os.path.join(output_folder, tar_names[0], file_path)
dst_path = os.path.join(output_folder, os.path.basename(file_path))
if src_path != dst_path:
copyfile(src_path, dst_path)
# remove the extracted folder
rmtree(os.path.join(output_folder, tar_names[0]))
@staticmethod
def _download_model_files(file_urls, output_folder, progress_bar):
"""Download the github releases"""
for file_url in file_urls:
# download the file
r = requests.get(file_url, stream=True)
# extract the file
bease_filename = file_url.split("/")[-1]
temp_zip_name = os.path.join(output_folder, bease_filename)
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
with open(temp_zip_name, "wb") as file:
if progress_bar:
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
for data in r.iter_content(block_size):
if progress_bar:
ModelManager.tqdm_progress.update(len(data))
file.write(data)
@staticmethod
def _check_dict_key(my_dict, key):
if key in my_dict.keys() and my_dict[key] is not None:
if not isinstance(key, str):
return True
if isinstance(key, str) and len(my_dict[key]) > 0:
return True
return False
+105
View File
@@ -0,0 +1,105 @@
# modified from https://github.com/LiyuanLucasLiu/RAdam
import math
import torch
from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if eps < 0.0:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]):
param["buffer"] = [[None, None, None] for _ in range(10)]
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]
)
super().__init__(params, defaults)
def __setstate__(self, state): # pylint: disable=useless-super-delegation
super().__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError("RAdam does not support sparse gradients")
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p_data_fp32)
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state["step"] += 1
buffered = group["buffer"][int(state["step"] % 10)]
if state["step"] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state["step"]
beta2_t = beta2 ** state["step"]
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = math.sqrt(
(1 - beta2_t)
* (N_sma - 4)
/ (N_sma_max - 4)
* (N_sma - 2)
/ N_sma
* N_sma_max
/ (N_sma_max - 2)
) / (1 - beta1 ** state["step"])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1 ** state["step"])
else:
step_size = -1
buffered[2] = step_size
# more conservative since it's an approximated value
if N_sma >= 5:
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
elif step_size > 0:
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
return loss
+201
View File
@@ -0,0 +1,201 @@
import math
import random
from typing import Callable, List, Union
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
class SubsetSampler(Sampler):
"""
Samples elements sequentially from a given list of indices.
Args:
indices (list): a sequence of indices
"""
def __init__(self, indices):
super().__init__(indices)
self.indices = indices
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
class PerfectBatchSampler(Sampler):
"""
Samples a mini-batch of indices for a balanced class batching
Args:
dataset_items(list): dataset items to sample from.
classes (list): list of classes of dataset_items to sample from.
batch_size (int): total number of samples to be sampled in a mini-batch.
num_gpus (int): number of GPU in the data parallel mode.
shuffle (bool): if True, samples randomly, otherwise samples sequentially.
drop_last (bool): if True, drops last incomplete batch.
"""
def __init__(
self,
dataset_items,
classes,
batch_size,
num_classes_in_batch,
num_gpus=1,
shuffle=True,
drop_last=False,
label_key="class_name",
):
super().__init__(dataset_items)
assert (
batch_size % (num_classes_in_batch * num_gpus) == 0
), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
label_indices = {}
for idx, item in enumerate(dataset_items):
label = item[label_key]
if label not in label_indices.keys():
label_indices[label] = [idx]
else:
label_indices[label].append(idx)
if shuffle:
self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
else:
self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
self._batch_size = batch_size
self._drop_last = drop_last
self._dp_devices = num_gpus
self._num_classes_in_batch = num_classes_in_batch
def __iter__(self):
batch = []
if self._num_classes_in_batch != len(self._samplers):
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
else:
valid_samplers_idx = None
iters = [iter(s) for s in self._samplers]
done = False
while True:
b = []
for i, it in enumerate(iters):
if valid_samplers_idx is not None and i not in valid_samplers_idx:
continue
idx = next(it, None)
if idx is None:
done = True
break
b.append(idx)
if done:
break
batch += b
if len(batch) == self._batch_size:
yield batch
batch = []
if valid_samplers_idx is not None:
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
if not self._drop_last:
if len(batch) > 0:
groups = len(batch) // self._num_classes_in_batch
if groups % self._dp_devices == 0:
yield batch
else:
batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
if len(batch) > 0:
yield batch
def __len__(self):
class_batch_size = self._batch_size // self._num_classes_in_batch
return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
def identity(x):
return x
class SortedSampler(Sampler):
"""Samples elements sequentially, always in the same order.
Taken from https://github.com/PetrochukM/PyTorch-NLP
Args:
data (iterable): Iterable data.
sort_key (callable): Specifies a function of one argument that is used to extract a
numerical comparison key from each list element.
Example:
>>> list(SortedSampler(range(10), sort_key=lambda i: -i))
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
"""
def __init__(self, data, sort_key: Callable = identity):
super().__init__(data)
self.data = data
self.sort_key = sort_key
zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
zip_ = sorted(zip_, key=lambda r: r[1])
self.sorted_indexes = [item[0] for item in zip_]
def __iter__(self):
return iter(self.sorted_indexes)
def __len__(self):
return len(self.data)
class BucketBatchSampler(BatchSampler):
"""Bucket batch sampler
Adapted from https://github.com/PetrochukM/PyTorch-NLP
Args:
sampler (torch.data.utils.sampler.Sampler):
batch_size (int): Size of mini-batch.
drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
than `batch_size`.
data (list): List of data samples.
sort_key (callable, optional): Callable to specify a comparison key for sorting.
bucket_size_multiplier (int, optional): Buckets are of size
`batch_size * bucket_size_multiplier`.
Example:
>>> sampler = WeightedRandomSampler(weights, len(weights))
>>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
"""
def __init__(
self,
sampler,
data,
batch_size,
drop_last,
sort_key: Union[Callable, List] = identity,
bucket_size_multiplier=100,
):
super().__init__(sampler, batch_size, drop_last)
self.data = data
self.sort_key = sort_key
_bucket_size = batch_size * bucket_size_multiplier
if hasattr(sampler, "__len__"):
_bucket_size = min(_bucket_size, len(sampler))
self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
def __iter__(self):
for idxs in self.bucket_sampler:
bucket_data = [self.data[idx] for idx in idxs]
sorted_sampler = SortedSampler(bucket_data, self.sort_key)
for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
sorted_idxs = [idxs[i] for i in batch_idx]
yield sorted_idxs
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
return math.ceil(len(self.sampler) / self.batch_size)
+505
View File
@@ -0,0 +1,505 @@
import os
import time
from typing import List
import numpy as np
import pysbd
import torch
from torch import nn
from TTS.config import load_config
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.models.vits import Vits
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
class Synthesizer(nn.Module):
def __init__(
self,
tts_checkpoint: str = "",
tts_config_path: str = "",
tts_speakers_file: str = "",
tts_languages_file: str = "",
vocoder_checkpoint: str = "",
vocoder_config: str = "",
encoder_checkpoint: str = "",
encoder_config: str = "",
vc_checkpoint: str = "",
vc_config: str = "",
model_dir: str = "",
voice_dir: str = None,
use_cuda: bool = False,
) -> None:
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder
model and synthesize speech from the provided text.
The text is divided into a list of sentences using `pysbd` and synthesize
speech on each sentence separately.
If you have certain special characters in your text, you need to handle
them before providing the text to Synthesizer.
TODO: set the segmenter based on the source language
Args:
tts_checkpoint (str, optional): path to the tts model file.
tts_config_path (str, optional): path to the tts config file.
vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`,
encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`,
vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`,
vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
"""
super().__init__()
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
self.tts_languages_file = tts_languages_file
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
self.encoder_checkpoint = encoder_checkpoint
self.encoder_config = encoder_config
self.vc_checkpoint = vc_checkpoint
self.vc_config = vc_config
self.use_cuda = use_cuda
self.tts_model = None
self.vocoder_model = None
self.vc_model = None
self.speaker_manager = None
self.tts_speakers = {}
self.language_manager = None
self.num_languages = 0
self.tts_languages = {}
self.d_vector_dim = 0
self.seg = self._get_segmenter("en")
self.use_cuda = use_cuda
self.voice_dir = voice_dir
if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
if vc_checkpoint:
self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir:
if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
else:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod
def _get_segmenter(lang: str):
"""get the sentence segmenter for the given language.
Args:
lang (str): target language code.
Returns:
[type]: [description]
"""
return pysbd.Segmenter(language=lang, clean=True)
def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None:
"""Load the voice conversion model.
1. Load the model config.
2. Init the model from the config.
3. Load the model weights.
4. Move the model to the GPU if CUDA is enabled.
Args:
vc_checkpoint (str): path to the model checkpoint.
tts_config_path (str): path to the model config file.
use_cuda (bool): enable/disable CUDA use.
"""
# pylint: disable=global-statement
self.vc_config = load_config(vc_config_path)
self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda:
self.vc_model.cuda()
def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the fairseq model from a directory.
We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory.
"""
self.tts_config = VitsConfig()
self.tts_model = Vits.init_from_config(self.tts_config)
self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
self.tts_config = self.tts_model.config
if use_cuda:
self.tts_model.cuda()
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the TTS model from a directory.
We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda:
self.tts_model.cuda()
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
"""Load the TTS model.
1. Load the model config.
2. Init the model from the config.
3. Load the model weights.
4. Move the model to the GPU if CUDA is enabled.
5. Init the speaker manager in the model.
Args:
tts_checkpoint (str): path to the model checkpoint.
tts_config_path (str): path to the model config file.
use_cuda (bool): enable/disable CUDA use.
"""
# pylint: disable=global-statement
self.tts_config = load_config(tts_config_path)
if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
raise ValueError("Phonemizer is not defined in the TTS config.")
self.tts_model = setup_tts_model(config=self.tts_config)
if not self.encoder_checkpoint:
self._set_speaker_encoder_paths_from_tts_config()
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda:
self.tts_model.cuda()
if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda)
def _set_speaker_encoder_paths_from_tts_config(self):
"""Set the encoder paths from the tts model config for models with speaker encoders."""
if hasattr(self.tts_config, "model_args") and hasattr(
self.tts_config.model_args, "speaker_encoder_config_path"
):
self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
"""Load the vocoder model.
1. Load the vocoder config.
2. Init the AudioProcessor for the vocoder.
3. Init the vocoder model from the config.
4. Move the model to the GPU if CUDA is enabled.
Args:
model_file (str): path to the model checkpoint.
model_config (str): path to the model config file.
use_cuda (bool): enable/disable CUDA use.
"""
self.vocoder_config = load_config(model_config)
self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
if use_cuda:
self.vocoder_model.cuda()
def split_into_sentences(self, text) -> List[str]:
"""Split give text into sentences.
Args:
text (str): input text in string format.
Returns:
List[str]: list of sentences.
"""
return self.seg.segment(text)
def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
"""Save the waveform as a file.
Args:
wav (List[int]): waveform as a list of values.
path (str): output path to save the waveform.
pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
"""
# if tensor convert to numpy
if torch.is_tensor(wav):
wav = wav.cpu().numpy()
if isinstance(wav, list):
wav = np.array(wav)
save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
return output_wav
def tts(
self,
text: str = "",
speaker_name: str = "",
language_name: str = "",
speaker_wav=None,
style_wav=None,
style_text=None,
reference_wav=None,
reference_speaker_name=None,
split_sentences: bool = True,
**kwargs,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
Args:
text (str): input text.
speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "".
language_name (str, optional): language id for multi-language models. Defaults to "".
speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None.
style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None.
split_sentences (bool, optional): split the input text into sentences. Defaults to True.
**kwargs: additional arguments to pass to the TTS model.
Returns:
List[int]: [description]
"""
start_time = time.time()
wavs = []
if not text and not reference_wav:
raise ValueError(
"You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API."
)
if text:
sens = [text]
if split_sentences:
print(" > Text splitted to sentences.")
sens = self.split_into_sentences(text)
print(sens)
# handle multi-speaker
if "voice_dir" in kwargs:
self.voice_dir = kwargs["voice_dir"]
kwargs.pop("voice_dir")
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts":
if self.tts_config.use_d_vector_file:
# get the average speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
speaker_name, num_samples=None, randomize=False
)
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name]
# handle Neon models with single speaker.
elif len(self.tts_model.speaker_manager.name_to_id) == 1:
speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0]
elif not speaker_name and not speaker_wav:
raise ValueError(
" [!] Looks like you are using a multi-speaker model. "
"You need to define either a `speaker_idx` or a `speaker_wav` to use a multi-speaker model."
)
else:
speaker_embedding = None
else:
if speaker_name and self.voice_dir is None:
raise ValueError(
f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
)
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager")
and self.tts_model.language_manager is not None
and not self.tts_config.model == "xtts"
):
if len(self.tts_model.language_manager.name_to_id) == 1:
language_id = list(self.tts_model.language_manager.name_to_id.values())[0]
elif language_name and isinstance(language_name, str):
try:
language_id = self.tts_model.language_manager.name_to_id[language_name]
except KeyError as e:
raise ValueError(
f" [!] Looks like you use a multi-lingual model. "
f"Language {language_name} is not in the available languages: "
f"{self.tts_model.language_manager.name_to_id.keys()}."
) from e
elif not language_name:
raise ValueError(
" [!] Look like you use a multi-lingual model. "
"You need to define either a `language_name` or a `style_wav` to use a multi-lingual model."
)
else:
raise ValueError(
f" [!] Missing language_ids.json file path for selecting language {language_name}."
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
)
# compute a new d_vector from the given clip.
if (
speaker_wav is not None
and self.tts_model.speaker_manager is not None
and hasattr(self.tts_model.speaker_manager, "encoder_ap")
and self.tts_model.speaker_manager.encoder_ap is not None
):
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
vocoder_device = "cpu"
use_gl = self.vocoder_model is None
if not use_gl:
vocoder_device = next(self.vocoder_model.parameters()).device
if self.use_cuda:
vocoder_device = "cuda"
if not reference_wav: # not voice conversion
for sen in sens:
if hasattr(self.tts_model, "synthesize"):
outputs = self.tts_model.synthesize(
text=sen,
config=self.tts_config,
speaker_id=speaker_name,
voice_dirs=self.voice_dir,
d_vector=speaker_embedding,
speaker_wav=speaker_wav,
language=language_name,
**kwargs,
)
else:
# synthesize voice
outputs = synthesis(
model=self.tts_model,
text=sen,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
speaker_id=speaker_id,
style_wav=style_wav,
style_text=style_text,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
language_id=language_id,
)
waveform = outputs["wav"]
if not use_gl:
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
# denormalize tts output based on tts audio config
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
scale_factor = [
1,
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
# run vocoder model
# [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
waveform = waveform.cpu()
if not use_gl:
waveform = waveform.numpy()
waveform = waveform.squeeze()
# trim silence
if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
waveform = trim_silence(waveform, self.tts_model.ap)
wavs += list(waveform)
wavs += [0] * 10000
else:
# get the speaker embedding or speaker id for the reference wav file
reference_speaker_embedding = None
reference_speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
if reference_speaker_name and isinstance(reference_speaker_name, str):
if self.tts_config.use_d_vector_file:
# get the speaker embedding from the saved d_vectors.
reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name(
reference_speaker_name
)[0]
reference_speaker_embedding = np.array(reference_speaker_embedding)[
None, :
] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
reference_speaker_id = self.tts_model.speaker_manager.name_to_id[reference_speaker_name]
else:
reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
reference_wav
)
outputs = transfer_voice(
model=self.tts_model,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
reference_wav=reference_wav,
speaker_id=speaker_id,
d_vector=speaker_embedding,
use_griffin_lim=use_gl,
reference_speaker_id=reference_speaker_id,
reference_d_vector=reference_speaker_embedding,
)
waveform = outputs
if not use_gl:
mel_postnet_spec = outputs[0].detach().cpu().numpy()
# denormalize tts output based on tts audio config
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
scale_factor = [
1,
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
# run vocoder model
# [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
waveform = waveform.cpu()
if not use_gl:
waveform = waveform.numpy()
wavs = waveform.squeeze()
# compute stats
process_time = time.time() - start_time
audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
print(f" > Processing time: {process_time}")
print(f" > Real-time factor: {process_time / audio_time}")
return wavs
+44
View File
@@ -0,0 +1,44 @@
import numpy as np
import torch
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
r"""Check model gradient against unexpected jumps and failures"""
skip_flag = False
if ignore_stopnet:
if not amp_opt_params:
grad_norm = torch.nn.utils.clip_grad_norm_(
[param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip
)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
else:
if not amp_opt_params:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
# compatibility with different torch versions
if isinstance(grad_norm, float):
if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True
else:
if torch.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True
return grad_norm, skip_flag
def gradual_training_scheduler(global_step, config):
"""Setup the gradual training schedule wrt number
of active GPUs"""
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
num_gpus = 1
new_values = None
# we set the scheduling wrt num_gpus
for values in config.gradual_training:
if global_step * num_gpus >= values[0]:
new_values = values
return new_values[1], new_values[2]
+88
View File
@@ -0,0 +1,88 @@
import torch
import torchaudio
def read_audio(path):
wav, sr = torchaudio.load(path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
return wav.squeeze(0), sr
def resample_wav(wav, sr, new_sr):
wav = wav.unsqueeze(0)
transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr)
wav = transform(wav)
return wav.squeeze(0)
def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False):
factor = new_sr / vad_sr
new_timestamps = []
if just_begging_end and timestamps:
# get just the start and end timestamps
new_dict = {"start": int(timestamps[0]["start"] * factor), "end": int(timestamps[-1]["end"] * factor)}
new_timestamps.append(new_dict)
else:
for ts in timestamps:
# map to the new SR
new_dict = {"start": int(ts["start"] * factor), "end": int(ts["end"] * factor)}
new_timestamps.append(new_dict)
return new_timestamps
def get_vad_model_and_utils(use_cuda=False, use_onnx=False):
model, utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=use_onnx, force_onnx_cpu=True
)
if use_cuda:
model = model.cuda()
get_speech_timestamps, save_audio, _, _, collect_chunks = utils
return model, get_speech_timestamps, save_audio, collect_chunks
def remove_silence(
model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False
):
# get the VAD model and utils functions
model, get_speech_timestamps, _, collect_chunks = model_and_utils
# read ground truth wav and resample the audio for the VAD
try:
wav, gt_sample_rate = read_audio(audio_path)
except:
print(f"> ❗ Failed to read {audio_path}")
return None, False
# if needed, resample the audio for the VAD model
if gt_sample_rate != vad_sample_rate:
wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate)
else:
wav_vad = wav
if use_cuda:
wav_vad = wav_vad.cuda()
# get speech timestamps from full audio file
speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768)
# map the current speech_timestamps to the sample rate of the ground truth audio
new_speech_timestamps = map_timestamps_to_new_sr(
vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end
)
# if have speech timestamps else save the wav
if new_speech_timestamps:
wav = collect_chunks(new_speech_timestamps, wav)
is_speech = True
else:
print(f"> The file {audio_path} probably does not have speech please check it !!")
is_speech = False
# save
torchaudio.save(out_path, wav[None, :], gt_sample_rate)
return out_path, is_speech