File size: 4,250 Bytes
8c70653 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import json
import os
import re
from typing import Dict
import fsspec
import yaml
from coqpit import Coqpit
from TTS.config.shared_configs import *
from TTS.utils.generic_utils import find_module
def read_json_with_comments(json_path):
    """Load a legacy JSON file that may contain ``//`` comments.

    Kept for backward compatibility with old config files that are not
    strict JSON.

    Args:
        json_path (str): Location of the file; it is passed to ``fsspec.open``.

    Returns:
        The object decoded by ``json.loads`` from the cleaned-up text.

    Raises:
        json.JSONDecodeError: The text is still not valid JSON after the
            comments were removed.
    """
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    # Join lines that end with a backslash (line continuation).
    input_str = re.sub(r"\\\n", "", input_str)
    # Remove `//` comments. String literals are matched first and written back
    # unchanged, so a value such as "https://host/path" is not mistaken for a
    # comment. The comment branch stops before the newline, which also covers
    # a comment on the very last line of a file without a trailing newline.
    input_str = re.sub(
        r'("(?:[^"\\]|\\.)*")|//[^\n]*',
        lambda match: match.group(1) or "",
        input_str,
    )
    return json.loads(input_str)
def register_config(model_name: str) -> Coqpit:
    """Look up the config class that belongs to a model name.

    The module ``<model_name>_config`` is searched for in the TTS, vocoder and
    encoder config packages, in that order. Every package is tried, so when
    more than one of them resolves, the result of the last one is returned.

    Args:
        model_name (str): Model name.

    Raises:
        ModuleNotFoundError: No matching config for the model name.

    Returns:
        Coqpit: config class.
    """
    config_name = model_name + "_config"
    found = None
    for package in ("TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"):
        try:
            candidate = find_module(package, config_name)
        except ModuleNotFoundError:
            # This package has no such config; keep whatever was found so far.
            continue
        found = candidate
    if found is None:
        raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
    return found
def _process_model_name(config_dict: Dict) -> str:
    """Return the model name of a config with the old vocoder suffixes removed.

    The name is read from the ``"model"`` key when it is present and from
    ``"generator_model"`` otherwise. Every occurrence of ``"_generator"`` and
    then of ``"_discriminator"`` is removed from it.

    Args:
        config_dict (Dict): A dictionary including the config fields.

    Raises:
        KeyError: Neither ``"model"`` nor ``"generator_model"`` is present.

    Returns:
        str: Formatted model name.
    """
    name_key = "model" if "model" in config_dict else "generator_model"
    cleaned = config_dict[name_key]
    for legacy_suffix in ("_generator", "_discriminator"):
        cleaned = cleaned.replace(legacy_suffix, "")
    return cleaned
def load_config(config_path: str) -> Coqpit:
    """Load a `json` or `yaml` file as a TTS config object.

    The file is read into a `dict`, the model name stored in it selects the
    matching Config class, and an instance of that class is populated from
    the `dict`.

    Args:
        config_path (str): path to the config file.

    Raises:
        TypeError: given config file has an unknown type.

    Returns:
        Coqpit: TTS config object.
    """
    ext = os.path.splitext(config_path)[1]
    if ext == ".json":
        try:
            with fsspec.open(config_path, "r", encoding="utf-8") as handle:
                loaded = json.load(handle)
        except json.decoder.JSONDecodeError:
            # Not strict JSON: retry with the reader that tolerates comments.
            loaded = read_json_with_comments(config_path)
    elif ext in (".yml", ".yaml"):
        with fsspec.open(config_path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle)
    else:
        raise TypeError(f" [!] Unknown config file type {ext}")

    # Work on a fresh dict rather than on the object returned by the parser.
    config_dict = dict(loaded)
    config_class = register_config(_process_model_name(config_dict).lower())
    config = config_class()
    config.from_dict(config_dict)
    return config
def check_config_and_model_args(config, arg_name, value):
    """Tell whether an argument of the config is equal to `value`.

    `config.model_args` takes precedence when it exists and contains
    `arg_name`; otherwise the argument is read from `config` itself. When the
    argument is in neither place the result is False.

    This is to patch up the compatibility between models with and without
    `model_args`.
    TODO: Remove this in the future with a unified approach.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        holder = config.model_args
    elif hasattr(config, arg_name):
        holder = config
    else:
        return False
    return holder[arg_name] == value
def get_from_config_or_model_args(config, arg_name):
    """Return `arg_name` from `config.model_args` when it is there, else from `config`."""
    in_model_args = hasattr(config, "model_args") and arg_name in config.model_args
    source = config.model_args if in_model_args else config
    return source[arg_name]
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
    """Return `arg_name` from `config.model_args` or `config`, else `def_val`.

    `config.model_args` is consulted first when it exists; `config` itself is
    the second choice. `def_val` is returned when neither of them has the
    argument.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name]
    return config[arg_name] if hasattr(config, arg_name) else def_val
|