|
import json |
|
import os |
|
import re |
|
from typing import Dict |
|
|
|
import fsspec |
|
import yaml |
|
from coqpit import Coqpit |
|
|
|
from TTS.config.shared_configs import * |
|
from TTS.utils.generic_utils import find_module |
|
|
|
|
|
def _strip_json_comments(text: str) -> str:
    """Return `text` with backslash-newline continuations and `//` line comments removed.

    Double-quoted string literals are matched first and copied through unchanged, so a
    `//` inside a value (e.g. ``"https://example.com"``) is not mistaken for a comment.
    A comment on the final line is removed even when the text has no trailing newline.
    The newline that ends a comment is kept, so line structure is preserved.
    """
    # Join lines that end with a backslash (same preprocessing step as before).
    text = re.sub(r"\\\n", "", text)
    # Group 1 is a complete JSON string literal (kept verbatim); the other alternative is
    # a `//` comment running to end of line (dropped). Scanning left to right means a
    # quote that appears inside a comment is swallowed by the comment match.
    return re.sub(r'("(?:\\.|[^"\\\n])*")|//[^\n]*', lambda m: m.group(1) or "", text)


def read_json_with_comments(json_path):
    """Load a JSON file that may contain `//` line comments (for backward compat).

    Args:
        json_path (str): path to the file; it is opened with `fsspec.open`.

    Returns:
        The value parsed by `json.loads` after comments are stripped.

    Raises:
        json.JSONDecodeError: if the text is still not valid JSON after stripping.
    """
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    return json.loads(_strip_json_comments(input_str))
|
|
|
|
|
def register_config(model_name: str) -> Coqpit:
    """Look up the config class that belongs to `model_name`.

    The module named `<model_name>_config` is searched for in the TTS, vocoder and
    encoder config packages, in that order. Every package is probed; when more than
    one package provides a match, the last one found is returned.

    Args:
        model_name (str): Model name.

    Raises:
        ModuleNotFoundError: None of the packages has a config for the model name.

    Returns:
        Coqpit: config class (the class itself, not an instance).
    """
    target = model_name + "_config"
    match = None
    for package in ("TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"):
        try:
            match = find_module(package, target)
        except ModuleNotFoundError:
            # Not in this package; keep whatever an earlier package produced.
            continue
    if match is not None:
        return match
    raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
|
|
|
|
|
def _process_model_name(config_dict: Dict) -> str:
    """Derive the model name used for config lookup (a band-aid for old `vocoder` names).

    The `model` key is used when present; otherwise the name is read from
    `generator_model`. Any `_generator` / `_discriminator` substrings are then removed.

    Args:
        config_dict (Dict): A dictionary including the config fields.

    Raises:
        KeyError: neither `model` nor `generator_model` is in `config_dict`.

    Returns:
        str: Formatted model name.
    """
    key = "model" if "model" in config_dict else "generator_model"
    name = config_dict[key]
    for suffix in ("_generator", "_discriminator"):
        name = name.replace(suffix, "")
    return name
|
|
|
|
|
def load_config(config_path: str) -> Coqpit:
    """Import `json` or `yaml` files as TTS configs.

    The file is first loaded as a `dict`. The model name is read from it to find the
    corresponding Config class, and that class is then instantiated and filled from
    the dict.

    Args:
        config_path (str): path to the config file; it is opened with `fsspec.open`.

    Raises:
        TypeError: the file extension is not `.json`, `.yml` or `.yaml` (compared
            case-insensitively), or the file does not hold a mapping at the top level
            (e.g. an empty YAML file).

    Returns:
        Coqpit: TTS config object.
    """
    ext = os.path.splitext(config_path)[1]
    # Compare case-insensitively so `config.JSON` / `config.YAML` are accepted too.
    ext_lower = ext.lower()
    if ext_lower in (".yml", ".yaml"):
        with fsspec.open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    elif ext_lower == ".json":
        try:
            with fsspec.open(config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.decoder.JSONDecodeError:
            # Strict parsing failed; retry while tolerating `//` comments (backward compat).
            data = read_json_with_comments(config_path)
    else:
        raise TypeError(f" [!] Unknown config file type {ext}")
    # `yaml.safe_load` returns None for an empty file, and either loader can return a
    # list or scalar. Reject those with a clear message instead of failing in `dict.update`.
    if not isinstance(data, dict):
        raise TypeError(
            f" [!] Config file {config_path} must contain a mapping at the top level, "
            f"got {type(data).__name__}."
        )
    config_dict = dict(data)
    model_name = _process_model_name(config_dict)
    config_class = register_config(model_name.lower())
    config = config_class()
    config.from_dict(config_dict)
    return config
|
|
|
|
|
def check_config_and_model_args(config, arg_name, value):
    """Compare `arg_name` against `value`, preferring `config.model_args` over `config`.

    If `config` has a `model_args` attribute containing `arg_name`, that entry is
    compared. Otherwise `config[arg_name]` is compared, provided `config` has an
    attribute of that name. If neither location defines the argument, False is returned.
    This patches up compatibility between models with and without `model_args`.

    TODO: Remove this in the future with a unified approach.
    """
    found_in_model_args = hasattr(config, "model_args") and arg_name in config.model_args
    if found_in_model_args:
        return config.model_args[arg_name] == value
    if not hasattr(config, arg_name):
        return False
    return config[arg_name] == value
|
|
|
|
|
def get_from_config_or_model_args(config, arg_name):
    """Return `arg_name` from `config.model_args` when present there, else `config[arg_name]`.

    No default is applied: a missing argument surfaces whatever error `config[arg_name]`
    raises.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name]
    return config[arg_name]
|
|
|
|
|
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
    """Return `arg_name` from `config.model_args` or `config`, falling back to `def_val`.

    `config.model_args` is consulted first (when that attribute exists). Next comes
    `config[arg_name]`, used only if `config` has an attribute with that name.
    Otherwise `def_val` is returned unchanged.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        return config.model_args[arg_name]
    return config[arg_name] if hasattr(config, arg_name) else def_val
|
|