Text-to-Audio / processors /phone_extractor.py
yuancwang
init
5548515
raw
history blame
4.83 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from tqdm import tqdm
from text.g2p_module import G2PModule, LexiconModule
from text.symbol_table import SymbolTable
"""
phoneExtractor: extract phone sequences from text
"""
class phoneExtractor:
    """Convert utterance text into phone sequences and collect the phone inventory.

    The backend is selected by ``cfg.preprocess.phone_extractor``: the values in
    ``_G2P_BACKENDS`` are served by ``G2PModule``, ``"lexicon"`` by ``LexiconModule``.
    Phones produced by the ``G2PModule`` backends are accumulated in
    ``self.phone_symbols`` so they can later be written to a symbol table file.
    """

    # Backends that are handled by G2PModule (as opposed to the lexicon lookup).
    _G2P_BACKENDS = ("espeak", "pypinyin", "pypinyin_initials_finals")

    def __init__(self, cfg, dataset_name=None, phone_symbol_file=None):
        """
        Args:
            cfg: config; ``cfg.preprocess.phone_extractor`` selects the backend.
            dataset_name: name of dataset, used to derive the symbol table path
                ``<processed_dir>/<dataset_name>/<symbols_dict>`` when
                ``phone_symbol_file`` is not given.
            phone_symbol_file: explicit path of the symbol table file; takes
                precedence over ``dataset_name``.
        Raises:
            NotImplementedError: if ``cfg.preprocess.phone_extractor`` is not a
                supported backend.
        """
        self.cfg = cfg
        self.dataset_name = dataset_name

        # phone symbols collected so far
        self.phone_symbols = set()

        # phone symbols dict file; stays None when no location can be derived,
        # which is reported when the table is actually saved
        self.phone_symbols_file = None
        if phone_symbol_file is not None:
            self.phone_symbols_file = phone_symbol_file
        elif dataset_name is not None:
            self.phone_symbols_file = os.path.join(
                cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.symbols_dict
            )

        # initialize g2p module
        backend = cfg.preprocess.phone_extractor
        if backend in self._G2P_BACKENDS:
            self.g2p_module = G2PModule(backend=backend)
        elif backend == "lexicon":
            assert (
                cfg.preprocess.lexicon_path != ""
            ), "cfg.preprocess.lexicon_path must be set for the lexicon extractor"
            self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path)
        else:
            # NotImplementedError subclasses RuntimeError, which is what the
            # previous bare ``raise`` produced, so existing handlers still match.
            raise NotImplementedError(f"No support for phone extractor: {backend}")

    def extract_phone(self, text):
        """
        Extract phone from text
        Args:
            text: text of utterance
        Returns:
            phone_seq: list of phones of the utterance
        Raises:
            NotImplementedError: if ``cfg.preprocess.phone_extractor`` is not a
                supported backend.
        """
        backend = self.cfg.preprocess.phone_extractor
        if backend in self._G2P_BACKENDS:
            # normalise curly double quotes to plain ASCII quotes before g2p
            text = text.replace("”", '"').replace("“", '"')
            phone = self.g2p_module.g2p_conversion(text=text)
            self.phone_symbols.update(phone)
            phone_seq = list(phone)
        elif backend == "lexicon":
            # the lexicon branch does not contribute to self.phone_symbols
            phone_seq = self.g2p_module.g2p_conversion(text)
            if not isinstance(phone_seq, list):
                phone_seq = phone_seq.split()
        else:
            raise NotImplementedError(f"No support for phone extractor: {backend}")

        return phone_seq

    def save_dataset_phone_symbols_to_table(self):
        """Merge the collected phones with any saved table and write the result.

        If ``self.phone_symbols_file`` already exists, its symbols are added to
        ``self.phone_symbols`` first; the union is then written back in sorted
        order.

        Raises:
            ValueError: if no symbol table path was configured at construction.
        """
        if self.phone_symbols_file is None:
            raise ValueError(
                "phone_symbols_file is not set; pass dataset_name or "
                "phone_symbol_file when constructing phoneExtractor"
            )

        # load and merge saved phone symbols
        if os.path.exists(self.phone_symbols_file):
            saved_symbols = SymbolTable.from_file(
                self.phone_symbols_file
            )._sym2id.keys()
            self.phone_symbols.update(saved_symbols)

        # save phone symbols
        phone_symbol_dict = SymbolTable()
        for s in sorted(self.phone_symbols):
            phone_symbol_dict.add(s)
        phone_symbol_dict.to_file(self.phone_symbols_file)
def extract_utt_phone_sequence(cfg, metadata):
    """
    Extract phone sequence from text and write one ``<Uid>.phone`` file per utterance
    Args:
        cfg: config; only the first entry of ``cfg.dataset`` is used as the
            dataset name, and output goes to
            ``<processed_dir>/<dataset_name>/<phone_dir>``
        metadata: list of dict, each dict contains "Uid", "Text"
    """
    dataset_name = cfg.dataset[0]

    # output path
    out_path = os.path.join(
        cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.phone_dir
    )
    os.makedirs(out_path, exist_ok=True)

    phone_extractor = phoneExtractor(cfg, dataset_name)

    for utt in tqdm(metadata):
        uid = utt["Uid"]
        text = utt["Text"]

        phone_seq = phone_extractor.extract_phone(text)

        phone_path = os.path.join(out_path, uid + ".phone")
        # Phone symbols may be non-ASCII, so pin the encoding instead of
        # depending on the platform's locale default.
        with open(phone_path, "w", encoding="utf-8") as fout:
            fout.write(" ".join(phone_seq))

    # Persist the symbol inventory once, after every utterance has contributed.
    # The lexicon extractor does not collect symbols, so there is nothing to save.
    if cfg.preprocess.phone_extractor != "lexicon":
        phone_extractor.save_dataset_phone_symbols_to_table()
def save_all_dataset_phone_symbols_to_table(self, cfg, dataset):
    """Unify the phone symbol tables of several datasets.

    Reads the symbol table of every dataset in ``dataset``, builds the union of
    all symbols, and overwrites each dataset's table with that union (symbols
    added in sorted order).

    Args:
        self: unused; kept because it is part of the existing signature.
        cfg: config providing ``preprocess.processed_dir`` and
            ``preprocess.symbols_dict``.
        dataset: iterable of dataset names; it is iterated twice.
    """

    def symbols_path(name):
        # <processed_dir>/<dataset name>/<symbols_dict>
        return os.path.join(
            cfg.preprocess.processed_dir, name, cfg.preprocess.symbols_dict
        )

    # Pass 1: gather every symbol already saved for any of the datasets.
    merged_symbols = set()
    for name in dataset:
        table_path = symbols_path(name)
        assert os.path.exists(table_path)
        saved_table = SymbolTable.from_file(table_path)
        merged_symbols |= set(saved_table._sym2id.keys())

    # Build one table holding the union, inserting symbols in sorted order.
    unified_table = SymbolTable()
    for symbol in sorted(merged_symbols):
        unified_table.add(symbol)

    # Pass 2: write the unified table back to each dataset's location.
    for name in dataset:
        unified_table.to_file(symbols_path(name))