Add files via upload
This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict
|
||||
|
||||
import fsspec
|
||||
import yaml
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config.shared_configs import *
|
||||
from TTS.utils.generic_utils import find_module
|
||||
|
||||
|
||||
def read_json_with_comments(json_path):
|
||||
"""for backward compat."""
|
||||
# fallback to json
|
||||
with fsspec.open(json_path, "r", encoding="utf-8") as f:
|
||||
input_str = f.read()
|
||||
# handle comments but not urls with //
|
||||
input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
|
||||
return json.loads(input_str)
|
||||
|
||||
def register_config(model_name: str) -> Coqpit:
|
||||
"""Find the right config for the given model name.
|
||||
|
||||
Args:
|
||||
model_name (str): Model name.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: No matching config for the model name.
|
||||
|
||||
Returns:
|
||||
Coqpit: config class.
|
||||
"""
|
||||
config_class = None
|
||||
config_name = model_name + "_config"
|
||||
|
||||
# TODO: fix this
|
||||
if model_name == "xtts":
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
|
||||
config_class = XttsConfig
|
||||
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"]
|
||||
for path in paths:
|
||||
try:
|
||||
config_class = find_module(path, config_name)
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
if config_class is None:
|
||||
raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
|
||||
return config_class
|
||||
|
||||
|
||||
def _process_model_name(config_dict: Dict) -> str:
|
||||
"""Format the model name as expected. It is a band-aid for the old `vocoder` model names.
|
||||
|
||||
Args:
|
||||
config_dict (Dict): A dictionary including the config fields.
|
||||
|
||||
Returns:
|
||||
str: Formatted modelname.
|
||||
"""
|
||||
model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
|
||||
model_name = model_name.replace("_generator", "").replace("_discriminator", "")
|
||||
return model_name
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Coqpit:
|
||||
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
|
||||
to find the corresponding Config class. Then initialize the Config.
|
||||
|
||||
Args:
|
||||
config_path (str): path to the config file.
|
||||
|
||||
Raises:
|
||||
TypeError: given config file has an unknown type.
|
||||
|
||||
Returns:
|
||||
Coqpit: TTS config object.
|
||||
"""
|
||||
config_dict = {}
|
||||
ext = os.path.splitext(config_path)[1]
|
||||
if ext in (".yml", ".yaml"):
|
||||
with fsspec.open(config_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
elif ext == ".json":
|
||||
try:
|
||||
with fsspec.open(config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except json.decoder.JSONDecodeError:
|
||||
# backwards compat.
|
||||
data = read_json_with_comments(config_path)
|
||||
else:
|
||||
raise TypeError(f" [!] Unknown config file type {ext}")
|
||||
config_dict.update(data)
|
||||
model_name = _process_model_name(config_dict)
|
||||
config_class = register_config(model_name.lower())
|
||||
config = config_class()
|
||||
config.from_dict(config_dict)
|
||||
return config
|
||||
|
||||
|
||||
def check_config_and_model_args(config, arg_name, value):
|
||||
"""Check the give argument in `config.model_args` if exist or in `config` for
|
||||
the given value.
|
||||
|
||||
Return False if the argument does not exist in `config.model_args` or `config`.
|
||||
This is to patch up the compatibility between models with and without `model_args`.
|
||||
|
||||
TODO: Remove this in the future with a unified approach.
|
||||
"""
|
||||
if hasattr(config, "model_args"):
|
||||
if arg_name in config.model_args:
|
||||
return config.model_args[arg_name] == value
|
||||
if hasattr(config, arg_name):
|
||||
return config[arg_name] == value
|
||||
return False
|
||||
|
||||
|
||||
def get_from_config_or_model_args(config, arg_name):
|
||||
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||
if hasattr(config, "model_args"):
|
||||
if arg_name in config.model_args:
|
||||
return config.model_args[arg_name]
|
||||
return config[arg_name]
|
||||
|
||||
|
||||
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
|
||||
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||
if hasattr(config, "model_args"):
|
||||
if arg_name in config.model_args:
|
||||
return config.model_args[arg_name]
|
||||
if hasattr(config, arg_name):
|
||||
return config[arg_name]
|
||||
return def_val
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,268 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List
|
||||
|
||||
from coqpit import Coqpit, check_argument
|
||||
from trainer import TrainerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseAudioConfig(Coqpit):
|
||||
"""Base config to definge audio processing parameters. It is used to initialize
|
||||
```TTS.utils.audio.AudioProcessor.```
|
||||
|
||||
Args:
|
||||
fft_size (int):
|
||||
Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024.
|
||||
|
||||
win_length (int):
|
||||
Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
|
||||
```fft_size```. Defaults to 1024.
|
||||
|
||||
hop_length (int):
|
||||
Number of audio samples between adjacent STFT columns. Defaults to 1024.
|
||||
|
||||
frame_shift_ms (int):
|
||||
Set ```hop_length``` based on milliseconds and sampling rate.
|
||||
|
||||
frame_length_ms (int):
|
||||
Set ```win_length``` based on milliseconds and sampling rate.
|
||||
|
||||
stft_pad_mode (str):
|
||||
Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
|
||||
|
||||
sample_rate (int):
|
||||
Audio sampling rate. Defaults to 22050.
|
||||
|
||||
resample (bool):
|
||||
Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
|
||||
|
||||
preemphasis (float):
|
||||
Preemphasis coefficient. Defaults to 0.0.
|
||||
|
||||
ref_level_db (int): 20
|
||||
Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air.
|
||||
Defaults to 20.
|
||||
|
||||
do_sound_norm (bool):
|
||||
Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
|
||||
|
||||
log_func (str):
|
||||
Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'.
|
||||
|
||||
do_trim_silence (bool):
|
||||
Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
|
||||
|
||||
do_amp_to_db_linear (bool, optional):
|
||||
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
|
||||
|
||||
do_amp_to_db_mel (bool, optional):
|
||||
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
|
||||
|
||||
pitch_fmax (float, optional):
|
||||
Maximum frequency of the F0 frames. Defaults to ```640```.
|
||||
|
||||
pitch_fmin (float, optional):
|
||||
Minimum frequency of the F0 frames. Defaults to ```1```.
|
||||
|
||||
trim_db (int):
|
||||
Silence threshold used for silence trimming. Defaults to 45.
|
||||
|
||||
do_rms_norm (bool, optional):
|
||||
enable/disable RMS volume normalization when loading an audio file. Defaults to False.
|
||||
|
||||
db_level (int, optional):
|
||||
dB level used for rms normalization. The range is -99 to 0. Defaults to None.
|
||||
|
||||
power (float):
|
||||
Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the
|
||||
artifacts in the synthesized voice. Defaults to 1.5.
|
||||
|
||||
griffin_lim_iters (int):
|
||||
Number of Griffing Lim iterations. Defaults to 60.
|
||||
|
||||
num_mels (int):
|
||||
Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80.
|
||||
|
||||
mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
|
||||
It needs to be adjusted for a dataset. Defaults to 0.
|
||||
|
||||
mel_fmax (float):
|
||||
Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
|
||||
|
||||
spec_gain (int):
|
||||
Gain applied when converting amplitude to DB. Defaults to 20.
|
||||
|
||||
signal_norm (bool):
|
||||
enable/disable signal normalization. Defaults to True.
|
||||
|
||||
min_level_db (int):
|
||||
minimum db threshold for the computed melspectrograms. Defaults to -100.
|
||||
|
||||
symmetric_norm (bool):
|
||||
enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else
|
||||
[0, k], Defaults to True.
|
||||
|
||||
max_norm (float):
|
||||
```k``` defining the normalization range. Defaults to 4.0.
|
||||
|
||||
clip_norm (bool):
|
||||
enable/disable clipping the our of range values in the normalized audio signal. Defaults to True.
|
||||
|
||||
stats_path (str):
|
||||
Path to the computed stats file. Defaults to None.
|
||||
"""
|
||||
|
||||
# stft parameters
|
||||
fft_size: int = 1024
|
||||
win_length: int = 1024
|
||||
hop_length: int = 256
|
||||
frame_shift_ms: int = None
|
||||
frame_length_ms: int = None
|
||||
stft_pad_mode: str = "reflect"
|
||||
# audio processing parameters
|
||||
sample_rate: int = 22050
|
||||
resample: bool = False
|
||||
preemphasis: float = 0.0
|
||||
ref_level_db: int = 20
|
||||
do_sound_norm: bool = False
|
||||
log_func: str = "np.log10"
|
||||
# silence trimming
|
||||
do_trim_silence: bool = True
|
||||
trim_db: int = 45
|
||||
# rms volume normalization
|
||||
do_rms_norm: bool = False
|
||||
db_level: float = None
|
||||
# griffin-lim params
|
||||
power: float = 1.5
|
||||
griffin_lim_iters: int = 60
|
||||
# mel-spec params
|
||||
num_mels: int = 80
|
||||
mel_fmin: float = 0.0
|
||||
mel_fmax: float = None
|
||||
spec_gain: int = 20
|
||||
do_amp_to_db_linear: bool = True
|
||||
do_amp_to_db_mel: bool = True
|
||||
# f0 params
|
||||
pitch_fmax: float = 640.0
|
||||
pitch_fmin: float = 1.0
|
||||
# normalization params
|
||||
signal_norm: bool = True
|
||||
min_level_db: int = -100
|
||||
symmetric_norm: bool = True
|
||||
max_norm: float = 4.0
|
||||
clip_norm: bool = True
|
||||
stats_path: str = None
|
||||
|
||||
def check_values(
|
||||
self,
|
||||
):
|
||||
"""Check config fields"""
|
||||
c = asdict(self)
|
||||
check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056)
|
||||
check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058)
|
||||
check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000)
|
||||
check_argument(
|
||||
"frame_length_ms",
|
||||
c,
|
||||
restricted=True,
|
||||
min_val=10,
|
||||
max_val=1000,
|
||||
alternative="win_length",
|
||||
)
|
||||
check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length")
|
||||
check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1)
|
||||
check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10)
|
||||
check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000)
|
||||
check_argument("power", c, restricted=True, min_val=1, max_val=5)
|
||||
check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000)
|
||||
|
||||
# normalization parameters
|
||||
check_argument("signal_norm", c, restricted=True)
|
||||
check_argument("symmetric_norm", c, restricted=True)
|
||||
check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000)
|
||||
check_argument("clip_norm", c, restricted=True)
|
||||
check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000)
|
||||
check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True)
|
||||
check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100)
|
||||
check_argument("do_trim_silence", c, restricted=True)
|
||||
check_argument("trim_db", c, restricted=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDatasetConfig(Coqpit):
|
||||
"""Base config for TTS datasets.
|
||||
|
||||
Args:
|
||||
formatter (str):
|
||||
Formatter name that defines used formatter in ```TTS.tts.datasets.formatter```. Defaults to `""`.
|
||||
|
||||
dataset_name (str):
|
||||
Unique name for the dataset. Defaults to `""`.
|
||||
|
||||
path (str):
|
||||
Root path to the dataset files. Defaults to `""`.
|
||||
|
||||
meta_file_train (str):
|
||||
Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
|
||||
Defaults to `""`.
|
||||
|
||||
ignored_speakers (List):
|
||||
List of speakers IDs that are not used at the training. Default None.
|
||||
|
||||
language (str):
|
||||
Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`.
|
||||
|
||||
phonemizer (str):
|
||||
Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`.
|
||||
|
||||
meta_file_val (str):
|
||||
Name of the dataset meta file that defines the instances used at validation.
|
||||
|
||||
meta_file_attn_mask (str):
|
||||
Path to the file that lists the attention mask files used with models that require attention masks to
|
||||
train the duration predictor.
|
||||
"""
|
||||
|
||||
formatter: str = ""
|
||||
dataset_name: str = ""
|
||||
path: str = ""
|
||||
meta_file_train: str = ""
|
||||
ignored_speakers: List[str] = None
|
||||
language: str = ""
|
||||
phonemizer: str = ""
|
||||
meta_file_val: str = ""
|
||||
meta_file_attn_mask: str = ""
|
||||
|
||||
def check_values(
|
||||
self,
|
||||
):
|
||||
"""Check config fields"""
|
||||
c = asdict(self)
|
||||
check_argument("formatter", c, restricted=True)
|
||||
check_argument("path", c, restricted=True)
|
||||
check_argument("meta_file_train", c, restricted=True)
|
||||
check_argument("meta_file_val", c, restricted=False)
|
||||
check_argument("meta_file_attn_mask", c, restricted=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseTrainingConfig(TrainerConfig):
|
||||
"""Base config to define the basic 🐸TTS training parameters that are shared
|
||||
among all the models. It is based on ```Trainer.TrainingConfig```.
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Name of the model that is used in the training.
|
||||
|
||||
num_loader_workers (int):
|
||||
Number of workers for training time dataloader.
|
||||
|
||||
num_eval_loader_workers (int):
|
||||
Number of workers for evaluation time dataloader.
|
||||
"""
|
||||
|
||||
model: str = None
|
||||
# dataloading
|
||||
num_loader_workers: int = 0
|
||||
num_eval_loader_workers: int = 0
|
||||
use_noise_augment: bool = False
|
||||
Reference in New Issue
Block a user