Spaces:
Runtime error
Runtime error
import re | |
import cn2an | |
import opencc | |
from text.symbols import punctuation, sh_symbols | |
# OpenCC config for the Shanghainese (zaonhe) lexicon. The path is relative,
# so it is resolved against the process's current working directory.
_ZAONHE_CONFIG = 'text/lexicon/zaonhe.json'
# Module-level converter shared by _g2p; built once at import time.
converter = opencc.OpenCC(_ZAONHE_CONFIG)
def number_to_shanghainese(text):
    """Spell out every Arabic numeral in *text* as Shanghainese-style Chinese numerals.

    Each run of digits (optionally with a decimal part) is converted with
    ``cn2an.an2cn`` and then adjusted: ``一十`` -> ``十``, ``二十`` -> ``廿``,
    and every ``二`` -> ``两``. Afterwards ``两`` is turned back into ``二``
    when it directly follows ``廿``, or follows ``十`` that is at the start of
    the number or not preceded by one of ``三``..``九``.

    Args:
        text: Input string possibly containing Arabic numerals.

    Returns:
        The string with every numeral replaced by its spelled-out form.
    """
    def _spell(match):
        spelled = cn2an.an2cn(match.group())
        # Order matters: 二十 must become 廿 before the blanket 二 -> 两.
        for old, new in (('一十', '十'), ('二十', '廿'), ('二', '两')):
            spelled = spelled.replace(old, new)
        return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', spelled)

    return re.sub(r'\d+(?:\.?\d+)?', _spell, text)
# Punctuation normalization table used by replace_punctuation. Both the
# full-width (Chinese) and ASCII forms of each mark are listed so that text
# typed with either input method is reduced to the same small target set.
# Characters that are neither mapped here nor present in ``punctuation`` are
# later removed by the character filter in replace_punctuation.
rep_map = {
    "：": ",",
    ":": ",",
    "；": ",",
    ";": ",",
    "，": ",",
    ",": ",",
    "。": ".",
    "！": "!",
    "!": "!",
    "？": "?",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "（": "'",
    "）": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "～": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}
def replace_punctuation(text):
    """Normalize punctuation and drop characters the Shanghainese frontend cannot use.

    Steps:
      1. Replace 嗯 with 恩 and 呣 with 母.
      2. Substitute every key of ``rep_map`` with its mapped value.
      3. Remove every character that is neither a CJK unified ideograph in
         U+4E00..U+9FA5 nor one of the characters in ``punctuation``.

    Args:
        text: Input string (already number-normalized by the caller, if desired).

    Returns:
        The cleaned string.
    """
    text = text.replace("嗯", "恩").replace("呣", "母")
    # Longest keys first so a multi-character key (e.g. "...") is never
    # shadowed by a shorter key that happens to be its prefix.
    keys = sorted(rep_map, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(p) for p in keys))
    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
    # Escape each allowed character: inside [...] an unescaped "-", "]", "\\"
    # or leading "^" would otherwise alter the meaning of the class.
    allowed = "".join(re.escape(p) for p in punctuation)
    replaced_text = re.sub(r"[^\u4e00-\u9fa5" + allowed + r"]+", "", replaced_text)
    return replaced_text
def valid_tone_char(char):
    """Return True if *char* should survive filtering of the converter output.

    A character is kept when it is a digit, whitespace, a punctuation mark,
    or has a corresponding ``"SH_" + char`` entry in ``sh_symbols``.
    """
    if char.isdigit() or char.isspace():
        return True
    return char in punctuation or ("SH_" + char) in sh_symbols
def g2p(text):
    """Convert *text* to phones, tones and word2ph, padded at both ends.

    Delegates to ``_g2p`` and then surrounds the sequences with a boundary
    element: phone ``"_"``, tone ``0`` and word2ph count ``1``.

    Returns:
        A ``(phones, tones, word2ph)`` tuple of new lists.
    """
    raw_phones, raw_tones, raw_word2ph = _g2p(text)
    return (
        ["_", *raw_phones, "_"],
        [0, *raw_tones, 0],
        [1, *raw_word2ph, 1],
    )
def _g2p(text):
    """Convert *text* into per-character phones, tones and a word-to-phone map.

    The text is run through the OpenCC ``converter``. ``-`` and ``$`` are
    removed, and characters rejected by ``valid_tone_char`` are dropped. The
    remaining characters are then scanned left to right:

    * a digit closes the current run of phone characters; every phone in the
      run gets that digit as its tone, and the run length is appended to
      ``word2ph`` (this may be 0 if the digit directly follows a boundary);
    * a punctuation mark closes any open run with tone 0, then is itself
      emitted as a phone with tone 0 and a ``word2ph`` count of 1;
    * any other character is appended to ``phones`` as part of the open run.

    A trailing run not closed by a digit or punctuation gets tone 0. Finally,
    each phone ``p`` is renamed to ``"SH_" + p`` when that name is in
    ``sh_symbols``.

    NOTE(review): the converter output is presumably romanized syllables
    each terminated by a tone digit — confirm against zaonhe.json.

    Returns:
        ``(phones, tones, word2ph)`` where ``len(tones) == len(phones)`` and
        ``sum(word2ph) == len(phones)`` (checked with ``assert``).
    """
    romanized = converter.convert(text).replace('-', '').replace('$', '')
    phone_chars = [c for c in romanized if valid_tone_char(c)]
    phones = []
    tones = []
    word2ph = []
    if not phone_chars:
        return phones, tones, word2ph

    phone_start_pos = 0
    for pos, char in enumerate(phone_chars):
        if char.isdigit():
            # Tone digit: assign it to every phone of the open run.
            run_length = pos - phone_start_pos
            word2ph.append(run_length)
            tones.extend([int(char)] * run_length)
            phone_start_pos = pos + 1
        elif char in punctuation:
            # Close an unterminated run with neutral tone 0 first.
            if pos != phone_start_pos:
                run_length = pos - phone_start_pos
                word2ph.append(run_length)
                tones.extend([0] * run_length)
            phones.append(char)
            tones.append(0)
            word2ph.append(1)
            phone_start_pos = pos + 1
        else:
            phones.append(char)

    # A run still open at the end of input was never given a tone.
    last_phone_char = phone_chars[-1]
    if not last_phone_char.isdigit() and last_phone_char not in punctuation:
        run_length = len(phone_chars) - phone_start_pos
        word2ph.append(run_length)
        tones.extend([0] * run_length)

    phones = ["SH_" + p if ("SH_" + p) in sh_symbols else p for p in phones]
    assert len(tones) == len(phones)
    assert sum(word2ph) == len(phones)
    return phones, tones, word2ph
def text_normalize(text):
    """Upper-case *text*, spell out its numerals, then normalize punctuation.

    Returns:
        The result of ``replace_punctuation(number_to_shanghainese(text.upper()))``.
    """
    return replace_punctuation(number_to_shanghainese(text.upper()))
def get_bert_feature(text, word2ph, device):
    """Forward to ``text.shanghainese_bert.get_bert_feature``.

    The import is done lazily inside the function, so the BERT module is only
    loaded when features are actually requested.
    """
    from text import shanghainese_bert as bert_module

    return bert_module.get_bert_feature(text, word2ph, device)