)
+ token_maps: "/home/ubuntu/001_PLBERT_JA/token_maps_jp_200k_0.pkl" # token map path
+
+ max_mel_length: 512 # max phoneme length
+
+ word_mask_prob: 0.15 # probability to mask the entire word
+ phoneme_mask_prob: 0.1 # probability to mask each phoneme
+ replace_prob: 0.2 # probablity to replace phonemes
+
+model_params:
+ vocab_size: 178
+ hidden_size: 768
+ num_attention_heads: 12
+ intermediate_size: 2048
+ max_position_embeddings: 512
+ num_hidden_layers: 12
+ dropout: 0.1
\ No newline at end of file
diff --git a/Utils/PLBERT/step_1050000.t7 b/Utils/PLBERT/step_1050000.t7
new file mode 100644
index 0000000000000000000000000000000000000000..060c878a6b1868d92f78d6956a077bfb3a7c4276
--- /dev/null
+++ b/Utils/PLBERT/step_1050000.t7
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66b36784a5a5523412d40c98767f81c4d9e97a8615df58f14e420b37bf024409
+size 1918057040
diff --git a/Utils/PLBERT/util.py b/Utils/PLBERT/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..91be427eae799c60aa4631f8c3c5467ea403f4bc
--- /dev/null
+++ b/Utils/PLBERT/util.py
@@ -0,0 +1,47 @@
+import os
+import yaml
+import torch
+from transformers import AlbertConfig, AlbertModel
+
+class CustomAlbert(AlbertModel):
+ def forward(self, *args, **kwargs):
+ # Call the original forward method
+ outputs = super().forward(*args, **kwargs)
+
+ # Only return the last_hidden_state
+ return outputs.last_hidden_state
+
+
+def load_plbert(log_dir):
+ config_path = os.path.join(log_dir, "config.yml")
+ plbert_config = yaml.safe_load(open(config_path))
+
+ albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
+ bert = CustomAlbert(albert_base_configuration)
+
+ files = os.listdir(log_dir)
+ ckpts = []
+ for f in os.listdir(log_dir):
+ if f.startswith("step_"): ckpts.append(f)
+
+ iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
+ iters = sorted(iters)[-1]
+
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
+ state_dict = checkpoint['net']
+ from collections import OrderedDict
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = k[7:] # remove `module.`
+ if name.startswith('encoder.'):
+ name = name[8:] # remove `encoder.`
+ new_state_dict[name] = v
+
+ # Check if 'embeddings.position_ids' exists before attempting to delete it
+ if not hasattr(bert.embeddings, 'position_ids'):
+ del new_state_dict["embeddings.position_ids"]
+
+
+ bert.load_state_dict(new_state_dict, strict=False)
+
+ return bert
diff --git a/Utils/__init__.py b/Utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/Utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/Utils/__pycache__/__init__.cpython-311.pyc b/Utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49feac9fb7367b25baeec177602a7aecca05e51f
Binary files /dev/null and b/Utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/Utils/phonemize/__pycache__/cotlet_phon.cpython-311.pyc b/Utils/phonemize/__pycache__/cotlet_phon.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6278e2cac7a62cebc594fd14cab4c59920687996
Binary files /dev/null and b/Utils/phonemize/__pycache__/cotlet_phon.cpython-311.pyc differ
diff --git a/Utils/phonemize/__pycache__/cotlet_phon_dir_backend.cpython-311.pyc b/Utils/phonemize/__pycache__/cotlet_phon_dir_backend.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48f626bf28675a71107af6a9225e77cee6376e1d
Binary files /dev/null and b/Utils/phonemize/__pycache__/cotlet_phon_dir_backend.cpython-311.pyc differ
diff --git a/Utils/phonemize/__pycache__/cotlet_utils.cpython-311.pyc b/Utils/phonemize/__pycache__/cotlet_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c9594d26dd44e3560703bc9f3f0e9bbc35f1ce8
Binary files /dev/null and b/Utils/phonemize/__pycache__/cotlet_utils.cpython-311.pyc differ
diff --git a/Utils/phonemize/__pycache__/mixed_phon.cpython-311.pyc b/Utils/phonemize/__pycache__/mixed_phon.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbe5b5c9a11c9d20e70d028cd8cd0b7c3120d5e7
Binary files /dev/null and b/Utils/phonemize/__pycache__/mixed_phon.cpython-311.pyc differ
diff --git a/Utils/phonemize/cotlet_phon.py b/Utils/phonemize/cotlet_phon.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc2486917038af4007118ea3cee92306d86db223
--- /dev/null
+++ b/Utils/phonemize/cotlet_phon.py
@@ -0,0 +1,162 @@
+from Utils.phonemize.cotlet_utils import *
+import cutlet
+
+katsu = cutlet.Cutlet(ensure_ascii=False)
+katsu.use_foreign_spelling = False
+
+def process_japanese_text(ml):
+ # Check for small characters and replace them
+ if any(char in ml for char in "ぁぃぅぇぉ"):
+
+ ml = ml.replace("ぁ", "あ")
+ ml = ml.replace("ぃ", "い")
+ ml = ml.replace("ぅ", "う")
+ ml = ml.replace("ぇ", "え")
+ ml = ml.replace("ぉ", "お")
+
+ # Initialize Cutlet for romaji conversion
+
+ # Convert to romaji and apply transformations
+ # output = katsu.romaji(ml, capitalize=False).lower()
+
+ output = katsu.romaji(apply_transformations(alphabetreading(ml)), capitalize=False).lower()
+
+
+ # Replace specific romaji sequences
+ if 'j' in output:
+ output = output.replace('j', "dʑ")
+ if 'tt' in output:
+ output = output.replace('tt', "ʔt")
+ if 't t' in output:
+ output = output.replace('t t', "ʔt")
+ if ' ʔt' in output:
+ output = output.replace(' ʔt', "ʔt")
+ if 'ssh' in output:
+ output = output.replace('ssh', "ɕɕ")
+
+ # Convert romaji to IPA
+ output = Roma2IPA(convert_numbers_in_string(output))
+
+
+ output = hira2ipa(output)
+
+ # Apply additional transformations
+ output = replace_chars_2(output)
+ output = replace_repeated_chars(replace_tashdid_2(output))
+ output = nasal_mapper(output)
+
+ # Final adjustments
+ if " ɴ" in output:
+ output = output.replace(" ɴ", "ɴ")
+
+ if ' neɽitai ' in output:
+ output = output.replace(' neɽitai ', "naɽitai")
+
+ if 'harɯdʑisama' in output:
+ output = output.replace('harɯdʑisama', "arɯdʑisama")
+
+
+ if "ki ni ɕinai" in output:
+ output = re.sub(r'(?= 3:
+# return pattern + "~~~"
+# return match.group(0)
+
+# # Pattern for space-separated repeats
+# pattern1 = r'((?:\S+\s+){1,5}?)(?:\1){2,}'
+# # Pattern for continuous repeats without spaces
+# pattern2 = r'(.+?)\1{2,}'
+
+# text = re.sub(pattern1, replace_repeats, text)
+# text = re.sub(pattern2, replace_repeats, text)
+# return text
+
+
+def replace_repeating_a(output):
+ # Define patterns and their replacements
+ patterns = [
+ (r'(aː)\s*\1+\s*', r'\1~'), # Replace repeating "aː" with "aː~~"
+ (r'(aːa)\s*aː', r'\1~'), # Replace "aːa aː" with "aː~~"
+ (r'aːa', r'aː~'), # Replace "aːa" with "aː~"
+ (r'naː\s*aː', r'naː~'), # Replace "naː aː" with "naː~"
+ (r'(oː)\s*\1+\s*', r'\1~'), # Replace repeating "oː" with "oː~~"
+ (r'(oːo)\s*oː', r'\1~'), # Replace "oːo oː" with "oː~~"
+ (r'oːo', r'oː~'), # Replace "oːo" with "oː~"
+ (r'(eː)\s*\1+\s*', r'\1~'),
+ (r'(e)\s*\1+\s*', r'\1~'),
+ (r'(eːe)\s*eː', r'\1~'),
+ (r'eːe', r'eː~'),
+ (r'neː\s*eː', r'neː~'),
+ ]
+
+
+ # Apply each pattern to the output
+ for pattern, replacement in patterns:
+ output = re.sub(pattern, replacement, output)
+
+ return output
+
+def phonemize(text):
+
+ # if "っ" in text:
+ # text = text.replace("っ","ʔ")
+
+ output = post_fix(process_japanese_text(text))
+ #output = text
+
+ if " ɴ" in output:
+ output = output.replace(" ɴ", "ɴ")
+ if "y" in output:
+ output = output.replace("y", "j")
+ if "ɯa" in output:
+ output = output.replace("ɯa", "wa")
+
+ if "a aː" in output:
+ output = output.replace("a aː","a~")
+ if "a a" in output:
+ output = output.replace("a a","a~")
+
+
+
+
+
+ output = replace_repeating_a((output))
+ output = re.sub(r'\s+~', '~', output)
+
+ if "oː~o oː~ o" in output:
+ output = output.replace("oː~o oː~ o","oː~~~~~~")
+ if "aː~aː" in output:
+ output = output.replace("aː~aː","aː~~~")
+ if "oɴ naː" in output:
+ output = output.replace("oɴ naː","onnaː")
+ if "aː~~ aː" in output:
+ output = output.replace("aː~~ aː","aː~~~~")
+ if "oː~o" in output:
+ output = output.replace("oː~o","oː~~")
+ if "oː~~o o" in output:
+ output = output.replace("oː~~o o","oː~~~~") # yeah I'm too tired to learn regex how did you know
+
+ output = random_space_fix(output)
+ output = random_sym_fix(output) # fixing some symbols, if they have a specific white space such as miku& sakura -> miku ando sakura
+ output = random_sym_fix_no_space(output) # same as above but for those without white space such as miku&sakura -> miku ando sakura
+ # if "ɯ" in output:
+ # output = output.replace("ɯ","U")ss
+ # if "ʔ" in output:
+ # output = output.replace("ʔ","!")
+
+ return output.lstrip()
+# def process_row(row):
+# return {'phonemes': [phonemize(word) for word in row['phonemes']]}
diff --git a/Utils/phonemize/cotlet_phon_dir_backend.py b/Utils/phonemize/cotlet_phon_dir_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b754b41780c14ea152c15392849b372d46643b
--- /dev/null
+++ b/Utils/phonemize/cotlet_phon_dir_backend.py
@@ -0,0 +1,143 @@
+from Utils.phonemize.cotlet_utils import *
+import cutlet
+
+katsu = cutlet.Cutlet(ensure_ascii=False)
+katsu.use_foreign_spelling = False
+
+def process_latin_text(ml):
+ # Check for small characters and replace them
+
+ # Initialize Cutlet for romaji conversion
+
+ # Convert to romaji and apply transformations
+ # output = katsu.romaji(ml, capitalize=False).lower()
+
+ output = ml.lower()
+
+
+ # Replace specific romaji sequences
+ if 'j' in output:
+ output = output.replace('j', "dʑ")
+ if 'y' in output:
+ output = output.replace('y', "j")
+ if 'tt' in output:
+ output = output.replace('tt', "ʔt")
+ if 't t' in output:
+ output = output.replace('t t', "ʔt")
+ if ' ʔt' in output:
+ output = output.replace(' ʔt', "ʔt")
+ if 'ssh' in output:
+ output = output.replace('ssh', "ɕɕ")
+
+ # Convert romaji to IPA
+ output = Roma2IPA(convert_numbers_in_string(output))
+
+
+ output = hira2ipa(output)
+
+ # Apply additional transformations
+ output = replace_chars_2(output)
+ output = replace_repeated_chars(replace_tashdid_2(output))
+ output = nasal_mapper(output)
+
+ # Final adjustments
+ if " ɴ" in output:
+ output = output.replace(" ɴ", "ɴ")
+
+ if ' neɽitai ' in output:
+ output = output.replace(' neɽitai ', "naɽitai")
+
+ if 'harɯdʑisama' in output:
+ output = output.replace('harɯdʑisama', "arɯdʑisama")
+
+
+ if "ki ni ɕinai" in output:
+ output = re.sub(r'(? miku ando sakura
+ output = random_sym_fix_no_space(output) # same as above but for those without white space such as miku&sakura -> miku ando sakura
+ # if "ɯ" in output:
+ # output = output.replace("ɯ","U")ss
+ # if "ʔ" in output:
+ # output = output.replace("ʔ","!")
+
+ return output.lstrip()
+# def process_row(row):
+# return {'phonemes': [phonemize(word) for word in row['phonemes']]}
diff --git a/Utils/phonemize/cotlet_utils.py b/Utils/phonemize/cotlet_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa37b77c3e30ec750b4b8f26b30debdd9de2b7d2
--- /dev/null
+++ b/Utils/phonemize/cotlet_utils.py
@@ -0,0 +1,1022 @@
+formal_to_informal = {
+
+
+
+ 'ワタクシ': 'わたし',
+ 'チカコ':'しゅうこ',
+ "タノヒト":"ほかのひと",
+
+ # Add more mappings as needed
+}
+
+formal_to_informal2 = {
+
+ "たのひと":"ほかのひと",
+ "すうは": "かずは",
+
+
+ # Add more mappings as needed
+}
+
+formal_to_informal3 = {
+
+ "%":"%",
+ "@": "あっとさいん",
+ "$":"どる",
+ "#":"はっしゅたぐ",
+ "$":"どる",
+ "#":"はっしゅたぐ",
+ "何が":"なにが",
+
+ "何も":"なにも",
+ "何か":"なにか",
+ # "奏":"かなで",
+ "何は":"なにが",
+ "お父様":"おとうさま",
+ "お兄様":"おにいさま",
+ "何を":"なにを",
+ "良い":"いい",
+ "李衣菜":"りいな",
+ "志希":"しき",
+ "種":"たね",
+ "方々":"かたがた",
+ "颯":"はやて",
+ "茄子さん":"かこさん",
+ "茄子ちゃん":"かこちゃん",
+ "涼ちゃん":"りょうちゃん",
+ "涼さん":"りょうさん",
+ "紗枝":"さえ",
+ "文香":"ふみか",
+ "私":"わたし",
+ "周子":"しゅうこ",
+ "イェ":"いえ",
+ "可憐":"かれん",
+ "加蓮":"かれん",
+ "・":".",
+ # "方の":"かたの",
+ # "気に":"きに",
+ "唯さん":"ゆいさん",
+ "唯ちゃん":"ゆいちゃん",
+ "聖ちゃん":"ひじりちゃん",
+ "他の":"ほかの",
+ "他に":"ほかに",
+ "一生懸命":"いっしょうけんめい",
+ "楓さん":"かえでさん",
+ "楓ちゃん":"かえでちゃん",
+ "内から":"ないから",
+ "の下で":"のしたで",
+
+}
+
+
+mapper = dict([
+
+ ("仕方","しかた"),
+ ("明日","あした"),
+ ('私',"わたし"),
+ ("従妹","いとこ"),
+
+ ("1人","ひとり"),
+ ("2人","ふたり"),
+
+ ("一期","いちご"),
+ ("一会","いちえ"),
+
+ ("♪","!"),
+ ("?","?"),
+
+ ("どんな方","どんなかた"),
+ ("ふたり暮らし","ふたりぐらし"),
+
+ ("新年","しんねん"),
+ ("来年","らいねん"),
+ ("去年","きょねん"),
+ ("壮年","そうねん"),
+ ("今年","ことし"),
+
+ ("昨年","さくねん"),
+ ("本年","ほんねん"),
+ ("平年","へいねん"),
+ ("閏年","うるうどし"),
+ ("初年","しょねん"),
+ ("少年","しょうねん"),
+ ("多年","たねん"),
+ ("青年","せいねん"),
+ ("中年","ちゅうねん"),
+ ("老年","ろうねん"),
+ ("成年","せいねん"),
+ ("幼年","ようねん"),
+ ("前年","ぜんねん"),
+ ("元年","がんねん"),
+ ("経年","けいねん"),
+ ("当年","とうねん"),
+
+ ("明年","みょうねん"),
+ ("歳年","さいねん"),
+ ("数年","すうねん"),
+ ("半年","はんとし"),
+ ("後年","こうねん"),
+ ("実年","じつねん"),
+ ("年年","ねんねん"),
+ ("連年","れんねん"),
+ ("暦年","れきねん"),
+ ("各年","かくねん"),
+ ("全年","ぜんねん"),
+
+ ("年を","としを"),
+ ("年が","としが"),
+ ("年も","としも"),
+ ("年は","としは"),
+
+
+ ("奏ちゃん","かなでちゃん"),
+ ("負けず嫌い","まけずぎらい"),
+ ("貴方","あなた"),
+ ("貴女","あなた"),
+ ("貴男","あなた"),
+
+ ("その節","そのせつ"),
+
+ ("何し","なにし"),
+ ("何する","なにする"),
+
+ ("心さん","しんさん"),
+ ("心ちゃん","しんちゃん"),
+
+ ("乃々","のの"),
+
+ ("身体の","からだの"),
+ ("身体が","からだが"),
+ ("身体を","からだを"),
+ ("身体は","からだは"),
+ ("身体に","からだに"),
+ ("正念場","しょうねんば"),
+ ("言う","いう"),
+
+
+ ("一回","いっかい"),
+ ("一曲","いっきょく"),
+ ("一日","いちにち"),
+ ("一言","ひとこと"),
+ ("一杯","いっぱい"),
+
+
+ ("方が","ほうが"),
+ ("縦輪城","じゅうりんしろ"),
+ ("深息","しんそく"),
+ ("家人","かじん"),
+ ("お返し","おかえし"),
+ ("化物語","ばけものがたり"),
+ ("阿良々木暦","あららぎこよみ"),
+ ("何より","なにより")
+
+
+])
+
+
+# Merge all dictionaries into one
+all_transformations = {**formal_to_informal, **formal_to_informal2, **formal_to_informal3, **mapper}
+
+def apply_transformations(text, transformations = all_transformations):
+ for key, value in transformations.items():
+ text = text.replace(key, value)
+ return text
+
+import re
+
+def number_to_japanese(num):
+ if not isinstance(num, int) or num < 0 or num > 9999:
+ return "Invalid input"
+
+ digits = ["", "いち", "に", "さん", "よん", "ご", "ろく", "なな", "はち", "きゅう"]
+ tens = ["", "じゅう", "にじゅう", "さんじゅう", "よんじゅう", "ごじゅう", "ろくじゅう", "ななじゅう", "はちじゅう", "きゅうじゅう"]
+ hundreds = ["", "ひゃく", "にひゃく", "さんびゃく", "よんひゃく", "ごひゃく", "ろっぴゃく", "ななひゃく", "はっぴゃく", "きゅうひゃく"]
+ thousands = ["", "せん", "にせん", "さんぜん", "よんせん", "ごせん", "ろくせん", "ななせん", "はっせん", "きゅうせん"]
+
+ if num == 0:
+ return "ゼロ"
+
+ result = ""
+ if num >= 1000:
+ result += thousands[num // 1000]
+ num %= 1000
+ if num >= 100:
+ result += hundreds[num // 100]
+ num %= 100
+ if num >= 10:
+ result += tens[num // 10]
+ num %= 10
+ if num > 0:
+ result += digits[num]
+
+ return result
+
+def convert_numbers_in_string(input_string):
+ # Regular expression to find numbers in the string
+ number_pattern = re.compile(r'\d+')
+
+ # Function to replace numbers with their Japanese pronunciation
+ def replace_with_japanese(match):
+ num = int(match.group())
+ return number_to_japanese(num)
+
+ # Replace all occurrences of numbers in the string
+ converted_string = number_pattern.sub(replace_with_japanese, input_string)
+ return converted_string
+
+
+roma_mapper = dict([
+
+ ################################
+
+ ("my","mʲ"),
+ ("by","bʲ"),
+ ("ny","nʲ"),
+ ("ry","rʲ"),
+ ("si","sʲ"),
+ ("ky","kʲ"),
+ ("gy","gʲ"),
+ ("dy","dʲ"),
+ ("di","dʲ"),
+ ("fi","fʲ"),
+ ("fy","fʲ"),
+ ("ch","tɕ"),
+ ("sh","ɕ"),
+
+ ################################
+
+ ("a","a"),
+ ("i","i"),
+ ("u","ɯ"),
+ ("e","e"),
+ ("o","o"),
+ ("ka","ka"),
+ ("ki","ki"),
+ ("ku","kɯ"),
+ ("ke","ke"),
+ ("ko","ko"),
+ ("sa","sa"),
+ ("shi","ɕi"),
+ ("su","sɯ"),
+ ("se","se"),
+ ("so","so"),
+ ("ta","ta"),
+ ("chi","tɕi"),
+ ("tsu","tsɯ"),
+ ("te","te"),
+ ("to","to"),
+ ("na","na"),
+ ("ni","ni"),
+ ("nu","nɯ"),
+ ("ne","ne"),
+ ("no","no"),
+ ("ha","ha"),
+ ("hi","çi"),
+ ("fu","ɸɯ"),
+ ("he","he"),
+ ("ho","ho"),
+ ("ma","ma"),
+ ("mi","mi"),
+ ("mu","mɯ"),
+ ("me","me"),
+ ("mo","mo"),
+ ("ra","ɽa"),
+ ("ri","ɽi"),
+ ("ru","ɽɯ"),
+ ("re","ɽe"),
+ ("ro","ɽo"),
+ ("ga","ga"),
+ ("gi","gi"),
+ ("gu","gɯ"),
+ ("ge","ge"),
+ ("go","go"),
+ ("za","za"),
+ ("ji","dʑi"),
+ ("zu","zɯ"),
+ ("ze","ze"),
+ ("zo","zo"),
+ ("da","da"),
+
+
+ ("zu","zɯ"),
+ ("de","de"),
+ ("do","do"),
+ ("ba","ba"),
+ ("bi","bi"),
+ ("bu","bɯ"),
+ ("be","be"),
+ ("bo","bo"),
+ ("pa","pa"),
+ ("pi","pi"),
+ ("pu","pɯ"),
+ ("pe","pe"),
+ ("po","po"),
+ ("ya","ja"),
+ ("yu","jɯ"),
+ ("yo","jo"),
+ ("wa","wa"),
+
+
+
+
+ ("a","a"),
+ ("i","i"),
+ ("u","ɯ"),
+ ("e","e"),
+ ("o","o"),
+ ("wa","wa"),
+ ("o","o"),
+
+
+ ("wo","o")])
+
+nasal_sound = dict([
+ # before m, p, b
+ ("ɴm","mm"),
+ ("ɴb", "mb"),
+ ("ɴp", "mp"),
+
+ # before k, g
+ ("ɴk","ŋk"),
+ ("ɴg", "ŋg"),
+
+ # before t, d, n, s, z, ɽ
+ ("ɴt","nt"),
+ ("ɴd", "nd"),
+ ("ɴn","nn"),
+ ("ɴs", "ns"),
+ ("ɴz","nz"),
+ ("ɴɽ", "nɽ"),
+
+ ("ɴɲ", "ɲɲ"),
+
+])
+
+def Roma2IPA(text):
+ orig = text
+
+ for k, v in roma_mapper.items():
+ text = text.replace(k, v)
+
+ return text
+
+def nasal_mapper(text):
+ orig = text
+
+
+ for k, v in nasal_sound.items():
+ text = text.replace(k, v)
+
+ return text
+
+def alphabetreading(text):
+ alphabet_dict = {"A": "エイ",
+ "B": "ビー",
+ "C": "シー",
+ "D": "ディー",
+ "E": "イー",
+ "F": "エフ",
+ "G": "ジー",
+ "H": "エイチ",
+ "I":"アイ",
+ "J":"ジェイ",
+ "K":"ケイ",
+ "L":"エル",
+ "M":"エム",
+ "N":"エヌ",
+ "O":"オー",
+ "P":"ピー",
+ "Q":"キュー",
+ "R":"アール",
+ "S":"エス",
+ "T":"ティー",
+ "U":"ユー",
+ "V":"ヴィー",
+ "W":"ダブリュー",
+ "X":"エックス",
+ "Y":"ワイ",
+ "Z":"ゼッド"}
+ text = text.upper()
+ text_ret = ""
+ for t in text:
+ if t in alphabet_dict:
+ text_ret += alphabet_dict[t]
+ else:
+ text_ret += t
+ return text_ret
+
+import re
+import cutlet
+
+roma_mapper_plus_2 = {
+
+"bjo":'bʲo',
+"rjo":"rʲo",
+"kjo":"kʲo",
+"kyu":"kʲu",
+
+}
+
+def replace_repeated_chars(input_string):
+ result = []
+ i = 0
+ while i < len(input_string):
+ if i + 1 < len(input_string) and input_string[i] == input_string[i + 1] and input_string[i] in 'aiueo':
+ result.append(input_string[i] + 'ː')
+ i += 2
+ else:
+ result.append(input_string[i])
+ i += 1
+ return ''.join(result)
+
+
+def replace_chars_2(text, mapping=roma_mapper_plus_2):
+
+
+ sorted_keys = sorted(mapping.keys(), key=len, reverse=True)
+
+ pattern = '|'.join(re.escape(key) for key in sorted_keys)
+
+
+ def replace(match):
+ key = match.group(0)
+ return mapping.get(key, key)
+
+ return re.sub(pattern, replace, text)
+
+
+def replace_tashdid_2(s):
+ vowels = 'aiueoɯ0123456789.?!_。؟?!...@@##$$%%^^&&**()()_+=[「」]>\`~~―ー∺"'
+ result = []
+
+ i = 0
+ while i < len(s):
+ if i < len(s) - 2 and s[i].lower() == s[i + 2].lower() and s[i].lower() not in vowels and s[i + 1] == ' ':
+ result.append('ʔ')
+ result.append(s[i + 2])
+ i += 3
+ elif i < len(s) - 1 and s[i].lower() == s[i + 1].lower() and s[i].lower() not in vowels:
+ result.append('ʔ')
+ result.append(s[i + 1])
+ i += 2
+ else:
+ result.append(s[i])
+ i += 1
+
+ return ''.join(result)
+
+def replace_tashdid(input_string):
+ result = []
+ i = 0
+ while i < len(input_string):
+ if i + 1 < len(input_string) and input_string[i] == input_string[i + 1] and input_string[i] not in 'aiueo':
+ result.append('ʔ')
+ result.append(input_string[i])
+ i += 2 # Skip the next character as it is already processed
+ else:
+ result.append(input_string[i])
+ i += 1
+ return ''.join(result)
+
+def hira2ipa(text, roma_mapper=roma_mapper):
+ keys_set = set(roma_mapper.keys())
+ special_rule = ("n", "ɴ")
+
+ transformed_text = []
+ i = 0
+
+ while i < len(text):
+ if text[i] == special_rule[0]:
+ if i + 1 == len(text) or text[i + 1] not in keys_set:
+ transformed_text.append(special_rule[1])
+ else:
+ transformed_text.append(text[i])
+ else:
+ transformed_text.append(text[i])
+
+ i += 1
+
+ return ''.join(transformed_text)
+
+import re
+
+
+k_mapper = dict([
+ ("ゔぁ","ba"),
+ ("ゔぃ","bi"),
+ ("ゔぇ","be"),
+ ("ゔぉ","bo"),
+ ("ゔゃ","bʲa"),
+ ("ゔゅ","bʲɯ"),
+ ("ゔゃ","bʲa"),
+ ("ゔょ","bʲo"),
+
+ ("ゔ","bɯ"),
+
+ ("あぁ"," aː"),
+ ("いぃ"," iː"),
+ ("いぇ"," je"),
+ ("いゃ"," ja"),
+ ("うぅ"," ɯː"),
+ ("えぇ"," eː"),
+ ("おぉ"," oː"),
+ ("かぁ"," kaː"),
+ ("きぃ"," kiː"),
+ ("くぅ","kɯː"),
+ ("くゃ","ka"),
+ ("くゅ","kʲɯ"),
+ ("くょ","kʲo"),
+ ("けぇ","keː"),
+ ("こぉ","koː"),
+ ("がぁ","gaː"),
+ ("ぎぃ","giː"),
+ ("ぐぅ","gɯː"),
+ ("ぐゃ","gʲa"),
+ ("ぐゅ","gʲɯ"),
+ ("ぐょ","gʲo"),
+ ("げぇ","geː"),
+ ("ごぉ","goː"),
+ ("さぁ","saː"),
+ ("しぃ","ɕiː"),
+ ("すぅ","sɯː"),
+ ("すゃ","sʲa"),
+ ("すゅ","sʲɯ"),
+ ("すょ","sʲo"),
+ ("せぇ","seː"),
+ ("そぉ","soː"),
+ ("ざぁ","zaː"),
+ ("じぃ","dʑiː"),
+ ("ずぅ","zɯː"),
+ ("ずゃ","zʲa"),
+ ("ずゅ","zʲɯ"),
+ ("ずょ","zʲo"),
+ ("ぜぇ","zeː"),
+ ("ぞぉ","zeː"),
+ ("たぁ","taː"),
+ ("ちぃ","tɕiː"),
+ ("つぁ","tsa"),
+ ("つぃ","tsi"),
+ ("つぅ","tsɯː"),
+ ("つゃ","tɕa"),
+ ("つゅ","tɕɯ"),
+ ("つょ","tɕo"),
+ ("つぇ","tse"),
+ ("つぉ","tso"),
+ ("てぇ","teː"),
+ ("とぉ","toː"),
+ ("だぁ","daː"),
+ ("ぢぃ","dʑiː"),
+ ("づぅ","dɯː"),
+ ("づゃ","zʲa"),
+ ("づゅ","zʲɯ"),
+ ("づょ","zʲo"),
+ ("でぇ","deː"),
+ ("どぉ","doː"),
+ ("なぁ","naː"),
+ ("にぃ","niː"),
+ ("ぬぅ","nɯː"),
+ ("ぬゃ","nʲa"),
+ ("ぬゅ","nʲɯ"),
+ ("ぬょ","nʲo"),
+ ("ねぇ","neː"),
+ ("のぉ","noː"),
+ ("はぁ","haː"),
+ ("ひぃ","çiː"),
+ ("ふぅ","ɸɯː"),
+ ("ふゃ","ɸʲa"),
+ ("ふゅ","ɸʲɯ"),
+ ("ふょ","ɸʲo"),
+ ("へぇ","heː"),
+ ("ほぉ","hoː"),
+ ("ばぁ","baː"),
+ ("びぃ","biː"),
+ ("ぶぅ","bɯː"),
+ ("ふゃ","ɸʲa"),
+ ("ぶゅ","bʲɯ"),
+ ("ふょ","ɸʲo"),
+ ("べぇ","beː"),
+ ("ぼぉ","boː"),
+ ("ぱぁ","paː"),
+ ("ぴぃ","piː"),
+ ("ぷぅ","pɯː"),
+ ("ぷゃ","pʲa"),
+ ("ぷゅ","pʲɯ"),
+ ("ぷょ","pʲo"),
+ ("ぺぇ","peː"),
+ ("ぽぉ","poː"),
+ ("まぁ","maː"),
+ ("みぃ","miː"),
+ ("むぅ","mɯː"),
+ ("むゃ","mʲa"),
+ ("むゅ","mʲɯ"),
+ ("むょ","mʲo"),
+ ("めぇ","meː"),
+ ("もぉ","moː"),
+ ("やぁ","jaː"),
+ ("ゆぅ","jɯː"),
+ ("ゆゃ","jaː"),
+ ("ゆゅ","jɯː"),
+ ("ゆょ","joː"),
+ ("よぉ","joː"),
+ ("らぁ","ɽaː"),
+ ("りぃ","ɽiː"),
+ ("るぅ","ɽɯː"),
+ ("るゃ","ɽʲa"),
+ ("るゅ","ɽʲɯ"),
+ ("るょ","ɽʲo"),
+ ("れぇ","ɽeː"),
+ ("ろぉ","ɽoː"),
+ ("わぁ","ɯaː"),
+ ("をぉ","oː"),
+
+ ("う゛","bɯ"),
+ ("でぃ","di"),
+ ("でぇ","deː"),
+ ("でゃ","dʲa"),
+ ("でゅ","dʲɯ"),
+ ("でょ","dʲo"),
+ ("てぃ","ti"),
+ ("てぇ","teː"),
+ ("てゃ","tʲa"),
+ ("てゅ","tʲɯ"),
+ ("てょ","tʲo"),
+ ("すぃ","si"),
+ ("ずぁ","zɯa"),
+ ("ずぃ","zi"),
+ ("ずぅ","zɯ"),
+ ("ずゃ","zʲa"),
+ ("ずゅ","zʲɯ"),
+ ("ずょ","zʲo"),
+ ("ずぇ","ze"),
+ ("ずぉ","zo"),
+ ("きゃ","kʲa"),
+ ("きゅ","kʲɯ"),
+ ("きょ","kʲo"),
+ ("しゃ","ɕʲa"),
+ ("しゅ","ɕʲɯ"),
+ ("しぇ","ɕʲe"),
+ ("しょ","ɕʲo"),
+ ("ちゃ","tɕa"),
+ ("ちゅ","tɕɯ"),
+ ("ちぇ","tɕe"),
+ ("ちょ","tɕo"),
+ ("とぅ","tɯ"),
+ ("とゃ","tʲa"),
+ ("とゅ","tʲɯ"),
+ ("とょ","tʲo"),
+ ("どぁ","doa"),
+ ("どぅ","dɯ"),
+ ("どゃ","dʲa"),
+ ("どゅ","dʲɯ"),
+ ("どょ","dʲo"),
+ ("どぉ","doː"),
+ ("にゃ","nʲa"),
+ ("にゅ","nʲɯ"),
+ ("にょ","nʲo"),
+ ("ひゃ","çʲa"),
+ ("ひゅ","çʲɯ"),
+ ("ひょ","çʲo"),
+ ("みゃ","mʲa"),
+ ("みゅ","mʲɯ"),
+ ("みょ","mʲo"),
+ ("りゃ","ɽʲa"),
+ ("りぇ","ɽʲe"),
+ ("りゅ","ɽʲɯ"),
+ ("りょ","ɽʲo"),
+ ("ぎゃ","gʲa"),
+ ("ぎゅ","gʲɯ"),
+ ("ぎょ","gʲo"),
+ ("ぢぇ","dʑe"),
+ ("ぢゃ","dʑa"),
+ ("ぢゅ","dʑɯ"),
+ ("ぢょ","dʑo"),
+ ("じぇ","dʑe"),
+ ("じゃ","dʑa"),
+ ("じゅ","dʑɯ"),
+ ("じょ","dʑo"),
+ ("びゃ","bʲa"),
+ ("びゅ","bʲɯ"),
+ ("びょ","bʲo"),
+ ("ぴゃ","pʲa"),
+ ("ぴゅ","pʲɯ"),
+ ("ぴょ","pʲo"),
+ ("うぁ","ɯa"),
+ ("うぃ","ɯi"),
+ ("うぇ","ɯe"),
+ ("うぉ","ɯo"),
+ ("うゃ","ɯʲa"),
+ ("うゅ","ɯʲɯ"),
+ ("うょ","ɯʲo"),
+ ("ふぁ","ɸa"),
+ ("ふぃ","ɸi"),
+ ("ふぅ","ɸɯ"),
+ ("ふゃ","ɸʲa"),
+ ("ふゅ","ɸʲɯ"),
+ ("ふょ","ɸʲo"),
+ ("ふぇ","ɸe"),
+ ("ふぉ","ɸo"),
+
+ ("あ"," a"),
+ ("い"," i"),
+ ("う","ɯ"),
+ ("え"," e"),
+ ("お"," o"),
+ ("か"," ka"),
+ ("き"," ki"),
+ ("く"," kɯ"),
+ ("け"," ke"),
+ ("こ"," ko"),
+ ("さ"," sa"),
+ ("し"," ɕi"),
+ ("す"," sɯ"),
+ ("せ"," se"),
+ ("そ"," so"),
+ ("た"," ta"),
+ ("ち"," tɕi"),
+ ("つ"," tsɯ"),
+ ("て"," te"),
+ ("と"," to"),
+ ("な"," na"),
+ ("に"," ni"),
+ ("ぬ"," nɯ"),
+ ("ね"," ne"),
+ ("の"," no"),
+ ("は"," ha"),
+ ("ひ"," çi"),
+ ("ふ"," ɸɯ"),
+ ("へ"," he"),
+ ("ほ"," ho"),
+ ("ま"," ma"),
+ ("み"," mi"),
+ ("む"," mɯ"),
+ ("め"," me"),
+ ("も"," mo"),
+ ("ら"," ɽa"),
+ ("り"," ɽi"),
+ ("る"," ɽɯ"),
+ ("れ"," ɽe"),
+ ("ろ"," ɽo"),
+ ("が"," ga"),
+ ("ぎ"," gi"),
+ ("ぐ"," gɯ"),
+ ("げ"," ge"),
+ ("ご"," go"),
+ ("ざ"," za"),
+ ("じ"," dʑi"),
+ ("ず"," zɯ"),
+ ("ぜ"," ze"),
+ ("ぞ"," zo"),
+ ("だ"," da"),
+ ("ぢ"," dʑi"),
+ ("づ"," zɯ"),
+ ("で"," de"),
+ ("ど"," do"),
+ ("ば"," ba"),
+ ("び"," bi"),
+ ("ぶ"," bɯ"),
+ ("べ"," be"),
+ ("ぼ"," bo"),
+ ("ぱ"," pa"),
+ ("ぴ"," pi"),
+ ("ぷ"," pɯ"),
+ ("ぺ"," pe"),
+ ("ぽ"," po"),
+ ("や"," ja"),
+ ("ゆ"," jɯ"),
+ ("よ"," jo"),
+ ("わ"," wa"),
+ ("ゐ"," i"),
+ ("ゑ"," e"),
+ ("ん"," ɴ"),
+ ("っ"," ʔ"),
+ ("ー"," ː"),
+
+ ("ぁ"," a"),
+ ("ぃ"," i"),
+ ("ぅ"," ɯ"),
+ ("ぇ"," e"),
+ ("ぉ"," o"),
+ ("ゎ"," ɯa"),
+ ("ぉ"," o"),
+ ("っ","?"),
+
+ ("を","o")
+
+])
+
+
+def post_fix(text):
+ orig = text
+
+ for k, v in k_mapper.items():
+ text = text.replace(k, v)
+
+ return text
+
+
+
+
+sym_ws = dict([
+
+ ("$ ","dorɯ"),
+ ("$ ","dorɯ"),
+
+ ("〇 ","marɯ"),
+ ("¥ ","eɴ"),
+
+ ("# ","haʔɕɯ tagɯ"),
+ ("# ","haʔɕɯ tagɯ"),
+
+ ("& ","ando"),
+ ("& ","ando"),
+
+ ("% ","paːsento"),
+ ("% ","paːsento"),
+
+ ("@ ","aʔto saiɴ"),
+ ("@ ","aʔto saiɴ")
+
+
+
+])
+
+def random_sym_fix(text): # with space
+ orig = text
+
+ for k, v in sym_ws.items():
+ text = text.replace(k, f" {v} ")
+
+ return text
+
+
+sym_ns = dict([
+
+ ("$","dorɯ"),
+ ("$","dorɯ"),
+
+ ("〇","marɯ"),
+ ("¥","eɴ"),
+
+ ("#","haʔɕɯ tagɯ"),
+ ("#","haʔɕɯ tagɯ"),
+
+ ("&","ando"),
+ ("&","ando"),
+
+ ("%","paːsento"),
+ ("%","paːsento"),
+
+ ("@","aʔto saiɴ"),
+ ("@","aʔto saiɴ"),
+
+ ("~","—"),
+ ("kʲɯɯdʑɯɯkʲɯɯ.kʲɯɯdʑɯɯ","kʲɯɯdʑɯɯ kʲɯɯ teɴ kʲɯɯdʑɯɯ")
+
+
+
+
+
+])
+
+def random_sym_fix_no_space(text):
+ orig = text
+
+ for k, v in sym_ns.items():
+ text = text.replace(k, f" {v} ")
+
+ return text
+
+
+spaces = dict([
+
+ ("ɯ ɴ","ɯɴ"),
+ ("na ɴ ","naɴ "),
+ (" mina ", " miɴna "),
+ ("ko ɴ ni tɕi ha","konnitɕiwa"),
+ ("ha i","hai"),
+ ("boɯtɕama","boʔtɕama"),
+ ("i eːi","ieːi"),
+ ("taiɕɯtsɯdʑoɯ","taiɕitsɯdʑoɯ"),
+ ("soɴna ka ze ni","soɴna fɯɯ ni"),
+ (" i e ","ke "),
+ ("�",""),
+ ("×"," batsɯ "),
+ ("se ka ɯndo","sekaɯndo"),
+ ("i i","iː"),
+ ("i tɕi","itɕi"),
+ ("ka i","kai"),
+ ("naɴ ga","nani ga"),
+ ("i eː i","ieːi"),
+
+ ("naɴ koɽe","nani koɽe"),
+ ("naɴ soɽe","nani soɽe"),
+ (" ɕeɴ "," seɴ "),
+
+ # ("konna","koɴna"),
+ # ("sonna"," soɴna "),
+ # ("anna","aɴna"),
+ # ("nn","ɴn"),
+
+ ("en ","eɴ "),
+ ("in ","iɴ "),
+ ("an ","aɴ "),
+ ("on ","oɴ "),
+ ("ɯn ","ɯɴ "),
+ # ("nd","ɴd"),
+
+ ("koɴd o","kondo"),
+ ("ko ɴ d o","kondo"),
+ ("ko ɴ do","kondo"),
+
+ ("oanitɕaɴ","oniːtɕaɴ"),
+ ("oanisaɴ","oniːsaɴ"),
+ ("oanisama","oniːsama"),
+ ("hoːmɯrɯɴɯ","hoːmɯrɯːmɯ"),
+ ("so ɴ na ","sonna"),
+ (" sonna "," sonna "),
+ (" konna "," konna "),
+ ("ko ɴ na ","konna"),
+ (" ko to "," koto "),
+ ("edʑdʑi","eʔtɕi"),
+ (" edʑdʑ "," eʔtɕi "),
+ (" dʑdʑ "," dʑiːdʑiː "),
+ ("secɯnd","sekaɯndo"),
+
+ ("ɴɯ","nɯ"),
+ ("ɴe","ne"),
+ ("ɴo","no"),
+ ("ɴa","na"),
+ ("ɴi","ni"),
+ ("ɴʲ","nʲ"),
+
+ ("hotond o","hotondo"),
+ ("hakoɴd e","hakoɴde"),
+ ("gakɯtɕi ɽi","gaʔtɕiɽi "),
+
+ (" ʔ","ʔ"),
+ ("ʔ ","ʔ"),
+
+ ("-","ː"),
+ ("- ","ː"),
+ ("--","~ː"),
+ ("~","—"),
+ ("、",","),
+
+
+ (" ː","ː"),
+ ('ka nade',"kanade"),
+
+ ("ohahasaɴ","okaːsaɴ"),
+ (" "," "),
+ ("viː","bɯiː"),
+ ("ːː","ː—"),
+
+ ("d ʑ","dʑ"),
+ ("d a","da"),
+ ("d e","de"),
+ ("d o","do"),
+ ("d ɯ","dɯ"),
+
+ ("niːɕiki","ni iɕiki"),
+ ("anitɕaɴ","niːtɕaɴ"),
+ ("daiːtɕi","dai itɕi"),
+ ("niːta","ni ita"),
+ ("niːrɯ","ni irɯ"),
+ ("a—","aː"),
+ ("waːis","wa ais"),
+ ("waːiɕ","wa aiɕ"),
+ ("aːt","a at"),
+ ("waːʔ", "wa aʔ"),
+
+ ("naɴ sono","nani sono"),
+ ("naɴ kono","nani kono"),
+ ("naɴ ano","nani ano"), # Cutlet please fix your shit
+ (" niːtaɽa"," ni itaɽa"),
+ ("doɽamaɕiːd","doɽama ɕiːdʲi"),
+ ("aɴ ta","anta"),
+ ("aɴta","anta"),
+ ("naniːʔteɴ","nani iʔteɴ"),
+ ("niːkite","ni ikite"),
+ ("niːʔ","ni iʔ"),
+ ("niːɯ","ni iɯ"),
+ ("niːw","ni iw"),
+ ("niːkɯ","ni ikɯ"),
+ ("de—","de e"),
+ ("aːj","aː aj"),
+ ("aːɽ","a aɽ"),
+ ("aːr","a ar"),
+ ("ɕiːk ","ɕi ik"),
+ ("ɕijoː neɴ","ɕoɯneɴ")
+
+
+])
+
+
+
+def random_space_fix(text):
+ orig = text
+
+ for k, v in spaces.items():
+ text = text.replace(k, v)
+
+ return text
\ No newline at end of file
diff --git a/Utils/phonemize/mixed_phon.py b/Utils/phonemize/mixed_phon.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af4fc7e75992868b05987e86d236560546ffecc
--- /dev/null
+++ b/Utils/phonemize/mixed_phon.py
@@ -0,0 +1,55 @@
+import re
+from Utils.phonemize.cotlet_phon import phonemize
+from Utils.phonemize.cotlet_phon_dir_backend import latn_phonemize
+
+# make sure you have correct spacing when using a mixture of japanese and romaji otherwise it goes into alphabet reading mode.
+
+def is_japanese(text):
+
+ japanese_ranges = [
+ (0x3040, 0x309F), # Hiragana
+ (0x30A0, 0x30FF), # Katakana
+ (0x4E00, 0x9FFF), # Kanji
+ ]
+
+ for char in text:
+ char_code = ord(char)
+ for start, end in japanese_ranges:
+ if start <= char_code <= end:
+ return True
+ return False
+
+def has_only_japanese(text):
+ # Remove spaces and check if all remaining characters are Japanese
+ text_no_spaces = ''.join(char for char in text if not char.isspace())
+ return all(is_japanese(char) for char in text_no_spaces)
+
+def has_only_romaji(text):
+ # Remove spaces and check if all remaining characters are ASCII
+ text_no_spaces = ''.join(char for char in text if not char.isspace())
+ return all(ord(char) < 128 for char in text_no_spaces)
+
+def mixed_phonemize(text):
+ # Split text into words while preserving spaces
+ words = re.findall(r'\S+|\s+', text)
+ result = []
+
+ for word in words:
+ if word.isspace():
+ result.append(word)
+ continue
+
+ if is_japanese(word):
+ result.append(phonemize(word))
+ else:
+ result.append(latn_phonemize(word))
+
+ return ''.join(result)
+
+def smart_phonemize(text):
+ if has_only_japanese(text):
+ return phonemize(text)
+ elif has_only_romaji(text):
+ return latn_phonemize(text)
+ else:
+ return mixed_phonemize(text)
\ No newline at end of file
diff --git a/__pycache__/cotlet_phon.cpython-311.pyc b/__pycache__/cotlet_phon.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..072f837d2e90bdaf001e9e88b4b11e53e4c5ca83
Binary files /dev/null and b/__pycache__/cotlet_phon.cpython-311.pyc differ
diff --git a/__pycache__/cotlet_phon_dir_backend.cpython-311.pyc b/__pycache__/cotlet_phon_dir_backend.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..953bdd5a76d5ca6f4f158c94f912892f5dbea7aa
Binary files /dev/null and b/__pycache__/cotlet_phon_dir_backend.cpython-311.pyc differ
diff --git a/__pycache__/cotlet_utils.cpython-311.pyc b/__pycache__/cotlet_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1735e43cfeffbf471b70dd17b087b7e899d14707
Binary files /dev/null and b/__pycache__/cotlet_utils.cpython-311.pyc differ
diff --git a/__pycache__/importable.cpython-311.pyc b/__pycache__/importable.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5728269147404685a2c48652ddad1dd3ee6ce49
Binary files /dev/null and b/__pycache__/importable.cpython-311.pyc differ
diff --git a/__pycache__/models.cpython-311.pyc b/__pycache__/models.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20e18bcda3a57b88ebf23bbf27d7c00928fe8b83
Binary files /dev/null and b/__pycache__/models.cpython-311.pyc differ
diff --git a/__pycache__/text_utils.cpython-311.pyc b/__pycache__/text_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..938b41bf361bfc5e8cd89ecc9c40ee9639956fc8
Binary files /dev/null and b/__pycache__/text_utils.cpython-311.pyc differ
diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5752589993e469e65ea8ccd68d15b8e62c612936
Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea546f9a95c9f1f2df06e39162750a2b1395d13f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,561 @@
+INTROTXT = """#
+Repo -> [Hugging Face - 🤗](https://huggingface.co/Respair/Project_Kanade_SpeechModel)
+This space uses Tsukasa (24khz).
+**Check the Read me tabs down below.**
+Enjoy!
+"""
+import gradio as gr
+import random
+import importable
+import torch
+import os
+from cotlet_phon import phonemize
+import numpy as np
+import pickle
+
+
+voices = {}
+example_texts = {}
+prompts = []
+inputs = []
+
+
+theme = gr.themes.Base(
+ font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+)
+
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+voicelist = [v for v in os.listdir("/home/ubuntu/Kanade_Project/gradio/Tsukasa_Speech/reference_sample_wavs")]
+
+
+
+for v in voicelist:
+ voices[v] = importable.compute_style_through_clip(f'reference_sample_wavs/{v}')
+
+
+with open(f'Inference/random_texts.txt', 'r') as r:
+ random_texts = [line.strip() for line in r]
+
+ example_texts = {f"{text[:30]}...": text for text in random_texts}
+
+def update_text_input(preview):
+
+ return example_texts[preview]
+
+def get_random_text():
+ return random.choice(random_texts)
+
+
+
+with open('Inference/prompt.txt', 'r') as p:
+ prompts = [line.strip() for line in p]
+
+with open('Inference/input_for_prompt.txt', 'r') as i:
+ inputs = [line.strip() for line in i]
+
+
+last_idx = None
+
+def get_random_prompt_pair():
+ global last_idx
+ max_idx = min(len(prompts), len(inputs)) - 1
+
+
+ random_idx = random.randint(0, max_idx)
+ while random_idx == last_idx:
+ random_idx = random.randint(0, max_idx)
+
+ last_idx = random_idx
+ return inputs[random_idx], prompts[random_idx]
+
+def Synthesize_Audio(text, voice, voice2, vcsteps, embscale, alpha, beta, ros, progress=gr.Progress()):
+
+
+ text = phonemize(text)
+
+
+ if voice2:
+ voice_style = importable.compute_style_through_clip(voice2)
+ else:
+ voice_style = voices[voice]
+
+ wav = importable.inference(
+ text,
+ voice_style,
+ alpha=alpha,
+ beta=beta,
+ diffusion_steps=vcsteps,
+ embedding_scale=embscale,
+ rate_of_speech=ros
+ )
+
+ return (24000, wav)
+
+
+def LongformSynth_Text(text, s_prev, Kotodama, alpha, beta, t, diffusion_steps, embedding_scale, rate_of_speech , progress=gr.Progress()):
+
+ japanese = text
+
+ # raw_jpn = japanese[japanese.find(":") + 2:]
+ # speaker = japanese[:japanese.find(":") + 2]
+
+
+ if ":" in japanese[:10]:
+ raw_jpn = japanese[japanese.find(":") + 2:]
+ speaker = japanese[:japanese.find(":") + 2]
+ else:
+ raw_jpn = japanese
+ speaker = ""
+
+ sentences = importable.sent_tokenizer.tokenize(raw_jpn)
+ sentences = importable.merging_sentences(sentences)
+
+ silence = 24000 * 0.5 # 500 ms of silence between outputs for a more natural transition
+ # sentences = sent_tokenize(text)
+ print(sentences)
+ wavs = []
+ s_prev = None
+ for text in sentences:
+
+ text_input = phonemize(text)
+ print('phonemes -> ', text_input)
+
+ Kotodama = importable.Kotodama_Sampler(importable.model, text=speaker + text, device=importable.device)
+
+ wav, s_prev = importable.Longform(text_input,
+ s_prev,
+ Kotodama,
+ alpha = alpha,
+ beta = beta,
+ t = t,
+ diffusion_steps=diffusion_steps, embedding_scale=embedding_scale, rate_of_speech=rate_of_speech)
+ wavs.append(wav)
+ wavs.append(np.zeros(int(silence)))
+
+ print('Synthesized: ')
+ return (24000, np.concatenate(wavs))
+
+
+
+
+def Inference_Synth_Prompt(text, description, Kotodama, alpha, beta, diffusion_steps, embedding_scale, rate_of_speech , progress=gr.Progress()):
+
+
+
+ prompt = f"""{description} \n text: {text}"""
+
+ print('prompt ->: ', prompt)
+
+ text = phonemize(text)
+
+ print('phonemes ->: ', text)
+
+ Kotodama = importable.Kotodama_Prompter(importable.model, text=prompt, device=importable.device)
+
+ wav = importable.inference(text,
+ Kotodama,
+ alpha = alpha,
+ beta = beta,
+ diffusion_steps=diffusion_steps, embedding_scale=embedding_scale, rate_of_speech=rate_of_speech)
+
+ wav = importable.trim_long_silences(wav)
+
+
+ print('Synthesized: ')
+ return (24000, wav)
+
+with gr.Blocks() as audio_inf:
+ with gr.Row():
+ with gr.Column(scale=1):
+ inp = gr.Textbox(label="Text", info="Enter the text", value="きみの存在は、私の心の中で燃える小さな光のよう。きみがいない時、世界は白黒の写真みたいに寂しくて、何も輝いてない。きみの笑顔だけが、私の灰色の日々に色を塗ってくれる。離れてる時間は、めちゃくちゃ長く感じられて、きみへの想いは風船みたいにどんどん膨らんでいく。きみなしの世界なんて、想像できないよ。", interactive=True, scale=5)
+ voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value=voicelist[-1], interactive=True)
+ voice_2 = gr.Audio(label="Upload your own Audio", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+
+ with gr.Accordion("Advanced Parameters", open=False):
+
+ alpha = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Alpha", info="a Diffusion sampler parameter handling the timbre, higher means less affected by the reference | 0 = diffusion is disabled", interactive=True)
+ beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Beta", info="a Diffusion sampler parameter, higher means less affected by the reference | 0 = diffusion is disabled", interactive=True)
+ multispeakersteps = gr.Slider(minimum=3, maximum=15, value=5, step=1, label="Diffusion Steps", interactive=True)
+ embscale = gr.Slider(minimum=1, maximum=5, value=1, step=0.1, label="Intensity", info="will impact the expressiveness, if you raise it too much it'll break.", interactive=True)
+ rate_of_speech = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1, label="Rate of Speech", info="Higher -> Faster", interactive=True)
+
+ with gr.Column(scale=1):
+ btn = gr.Button("Synthesize", variant="primary")
+ audio = gr.Audio(interactive=False, label="Synthesized Audio", waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+ btn.click(Synthesize_Audio, inputs=[inp, voice, voice_2, multispeakersteps, embscale, alpha, beta, rate_of_speech], outputs=[audio], concurrency_limit=4)
+
+# Kotodama Text sampler Synthesis Block
+with gr.Blocks() as longform:
+ with gr.Row():
+ with gr.Column(scale=1):
+ inp_longform = gr.Textbox(
+ label="Text",
+ info="Enter the text [Speaker: Text] | Also works without any name.",
+ value=list(example_texts.values())[0],
+ interactive=True,
+ scale=5
+ )
+
+ with gr.Row():
+ example_dropdown = gr.Dropdown(
+ choices=list(example_texts.keys()),
+ label="Example Texts [pick one!]",
+ value=list(example_texts.keys())[0],
+ interactive=True
+ )
+
+ example_dropdown.change(
+ fn=update_text_input,
+ inputs=[example_dropdown],
+ outputs=[inp_longform]
+ )
+
+ with gr.Accordion("Advanced Parameters", open=False):
+
+ alpha_longform = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Alpha",
+ info="a Diffusion parameter handling the timbre, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ beta_longform = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Beta",
+ info="a Diffusion parameter, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ diffusion_steps_longform = gr.Slider(minimum=3, maximum=15, value=10, step=1,
+ label="Diffusion Steps",
+ interactive=True)
+ embedding_scale_longform = gr.Slider(minimum=1, maximum=5, value=1.25, step=0.1,
+ label="Intensity",
+ info="a Diffusion parameter, it will impact the expressiveness, if you raise it too much it'll break.",
+ interactive=True)
+
+ rate_of_speech_longform = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1,
+ label="Rate of Speech",
+ info="Higher = Faster",
+ interactive=True)
+
+ with gr.Column(scale=1):
+ btn_longform = gr.Button("Synthesize", variant="primary")
+ audio_longform = gr.Audio(interactive=False,
+ label="Synthesized Audio",
+ waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+
+ btn_longform.click(LongformSynth_Text,
+ inputs=[inp_longform,
+ gr.State(None), # s_prev
+ gr.State(None), # Kotodama
+ alpha_longform,
+ beta_longform,
+ gr.State(.8), # t parameter
+ diffusion_steps_longform,
+ embedding_scale_longform,
+ rate_of_speech_longform],
+ outputs=[audio_longform],
+ concurrency_limit=4)
+
+# Kotodama prompt sampler Inference Block
+with gr.Blocks() as prompt_inference:
+ with gr.Row():
+ with gr.Column(scale=1):
+ text_prompt = gr.Textbox(
+ label="Text",
+ info="Enter the text to synthesize. This text will also be fed to the encoder. Make sure to see the Read Me for more details!",
+ value=inputs[0],
+ interactive=True,
+ scale=5
+ )
+ description_prompt = gr.Textbox(
+ label="Description",
+ info="Enter a highly detailed, descriptive prompt that matches the vibe of your text to guide the synthesis.",
+ value=prompts[0],
+ interactive=True,
+ scale=7
+ )
+
+ with gr.Row():
+ random_btn = gr.Button('Random Example', variant='secondary')
+
+ with gr.Accordion("Advanced Parameters", open=True):
+ embedding_scale_prompt = gr.Slider(minimum=1, maximum=5, value=1, step=0.25,
+ label="Intensity",
+ info="it will impact the expressiveness, if you raise it too much it'll break.",
+ interactive=True)
+ alpha_prompt = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Alpha",
+ info="a Diffusion sampler parameter handling the timbre, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ beta_prompt = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Beta",
+ info="a Diffusion sampler parameter, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ diffusion_steps_prompt = gr.Slider(minimum=3, maximum=15, value=10, step=1,
+ label="Diffusion Steps",
+ interactive=True)
+ rate_of_speech_prompt = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1,
+ label="Rate of Speech",
+ info="Higher = Faster",
+ interactive=True)
+ with gr.Column(scale=1):
+ btn_prompt = gr.Button("Synthesize with Prompt", variant="primary")
+ audio_prompt = gr.Audio(interactive=False,
+ label="Prompt-based Synthesized Audio",
+ waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+
+
+ random_btn.click(
+ fn=get_random_prompt_pair,
+ inputs=[],
+ outputs=[text_prompt, description_prompt]
+ )
+
+ btn_prompt.click(Inference_Synth_Prompt,
+ inputs=[text_prompt,
+ description_prompt,
+ gr.State(None),
+ alpha_prompt,
+ beta_prompt,
+ diffusion_steps_prompt,
+ embedding_scale_prompt,
+ rate_of_speech_prompt],
+ outputs=[audio_prompt],
+ concurrency_limit=4)
+
+notes = """
+Notes
+
+
+This work is somewhat different from your typical speech model. It offers a high degree of control
+over the generation process, which means it's easy to inadvertently produce unimpressive outputs.
+
+
+
+Kotodama and the Diffusion sampler can significantly help guide the generation towards
+something that aligns with your input, but they aren't foolproof.
+
+
+
+The model's peak performance is achieved when the Diffusion sampler and Kotodama work seamlessly together.
+However, we won't see that level of performance here because this checkpoint is somewhat undertrained
+due to my time and resource constraints. (Tsumugi should be better in this regard,
+albeit if the diffusion works at all on your hardware.)
+Hopefully, you can further fine-tune this model (or train from scratch) to achieve even better results!
+
+
+
+The prompt encoder is also highly experimental and should be treated as a proof of concept. Due to the
+overwhelming ratio of female to male speakers and the wide variation in both speakers and their expressions,
+the prompt encoder may occasionally produce subpar or contradicting outputs. For example, high expressiveness alongside
+high pitch has been associated with females speakers simply because I had orders of magnitude more of them in the dataset.
+
+
+
+________________________________________________________
+A useful note about the voice design and prompting:
\n
+The vibe of the dialogue impacts the generated voice since the Japanese dialogue
+and the prompts were jointly trained. This is a peculiar feature of the Japanese lanuage.
+For example if you use 俺 (ore)、僕(boku) or your input is overall masculine
+you may get a guy's voice, even if you describe it as female in the prompt.
\n
+The Japanese text that is fed to the prompt doesn't necessarily have to be
+the same as your input, but we can't do it in this demo
+to not make the page too convoluted. In a real world scenario, you can just use a
+prompt with a suitable Japanese text to guide the model, get the style
+then move on to apply it to whatever dialogue you wish your model to speak.
+
+
+
+________________________________________________________
+
+The pitch information in my data was accurately calculated, but it only works in comparison to the other speakers
+so you may find a deep pitch may not be exactly too deep; although it actually is
+when you compare it to others within the same data, also some of the gender labels
+are inaccurate since we used a model to annotate them.
\n
+The main goal of this inference method is to demonstrate that style can be mapped to description's embeddings
+yielding reasonably good results.
+
+
+
+Overall, I'm confident that with a bit of experimentation, you can achieve reasonbaly good results.
+The model should work well out of the box 90% of the time without the need for extensive tweaking.
+However, here are some tips in case you encounter issues:
+
+
+Tips:
+
+
+ -
+ Ensure that your input closely matches your reference (audio or text prompt) in terms of tone,
+ non-verbal cues, duration, etc.
+
+
+ -
+ If your audio is too long but the input is too short, the speech rate will be slow, and vice versa.
+
+
+ -
+ Experiment with the alpha, beta, and Intensity parameters. The Diffusion
+ sampler is non-deterministic, so regenerate a few times if you're not satisfied with the output.
+
+
+ -
+ The speaker's share and expressive distribution in the dataset significantly impact the quality;
+ you won't necessarily get perfect results with all speakers.
+
+
+ -
+ Punctuation is very important, for example if you add «!» mark it will raise the voice or make it more intense.
+
+
+ -
+ Not all speakers are equal. Less represented speakers or out-of-distribution inputs may result
+ in artifacts.
+
+
+ -
+ If the Diffusion sampler works but the speaker didn't have a certain expression (e.g., extreme anger)
+ in the dataset, try raising the diffusion sampler's parameters and let it handle everything. Though
+ it may result in less speaker similarity, the ideal way to handle this is to cook new vectors by
+ transferring an emotion from one speaker to another. But you can't do that in this space.
+
+
+ -
+ For voice-based inference, you can use litagin's awesome Moe-speech dataset,
+ as part of the training data includes a portion of that.
+
+
+ -
+ you may also want to tweak the phonemes if you're going for something wild.
+ i have used cutlet in the backend, but that doesn't seem to like some of my mappings.
+
+
+
+
+"""
+
+
+notes_jp = """
+メモ
+
+
+この作業は、典型的なスピーチモデルとは少し異なります。生成プロセスに対して高い制御を提供するため、意図せずに
+比較的にクオリティーの低い出力を生成してしまうことが容易です。
+
+
+
+KotodamaとDiffusionサンプラーは、入力に沿ったものを生成するための大きな助けとなりますが、
+万全というわけではありません。
+
+
+
+モデルの最高性能は、DiffusionサンプラーとKotodamaがシームレスに連携することで達成されます。しかし、
+このチェックポイントは時間とリソースの制約からややTrain不足であるため、そのレベルの性能はここでは見られません。
+(この件について、「紬」のチェックポイントの方がいいかもしれません。でもまぁ、みなさんのハードに互換性があればね。)
+おそらく、このモデルをさらにFinetuningする(または最初からTrainする)ことで、より良い結果が得られるでしょう。
+
+
+_____________________________________________
\n
+音声デザインとプロンプトに関する有用なメモ:
+ダイアログの雰囲気は、日本語のダイアログとプロンプトが共同でTrainされたため、生成される音声に影響を与えます。
+これは日本語の特徴的な機能です。例えば、「俺」や「僕」を使用したり、全体的に男性らしい入力をすると、
+プロンプトで女性と記述していても、男性の声が得られる可能性があります。
+プロンプトに入力される日本語のテキストは、必ずしも入力内容と同じである必要はありませんが、
+このデモではページが複雑になりすぎないようにそれを行うことはできません。
+実際のシナリオでは、適切な日本語のテキストを含むプロンプトを使用してモデルを導き、
+スタイルを取得した後、それを希望するダイアログに適用することができます。
+
+_____________________________________________
\n
+
+
+プロンプトエンコーダも非常に実験的であり、概念実証として扱うべきです。女性話者対男性話者の比率が圧倒的で、
+また話者とその表現に大きなバリエーションがあるため、エンコーダは質の低い出力を生成する可能性があります。
+例えば、高い表現力は、データセットに多く含まれていた女性話者と関連付けられています。
+それに、データのピッチ情報は正確に計算されましたが、それは他のスピーカーとの比較でしか機能しません...
+だから、深いピッチが必ずしも深すぎるわけではないことに気づくかもしれません。
+ただし、実際には、同じデータ内の他の人と比較すると、深すぎます。このインフレンスの主な目的は、
+スタイルベクトルを記述にマッピングし、合理的に良い結果を得ることにあります。
+
+
+
+全体として、少しの実験でほぼ望む結果を達成できると自信を持っています。90%のケースで、大幅な調整を必要とせず、
+そのままでうまく動作するはずです。しかし、問題が発生した場合のためにいくつかのヒントがあります:
+
+
+ヒント:
+
+
+ -
+ 入力がリファレンス(音声またはテキストプロンプト)とトーン、非言語的な手がかり、
+ 長さなどで密接に一致していることを確認してください。
+
+
+ -
+ 音声が長すぎるが入力が短すぎる場合、話速が遅くなります。その逆もまた同様です。
+
+
+ -
+ アルファ、ベータ、および埋め込みスケールのパラメータを試行錯誤してください。Diffusionサンプラーは
+ 非決定的なので、満足のいく出力が得られない場合は何度か再生成してください。
+
+
+ -
+ データセット内の話者の分布と表現力の分布は品質に大きく影響します。
+ すべての話者で必ずしも完璧な結果が得られるわけではありません。
+
+
+ -
+ 句読点は重要です。たとえな、「!」を使えば、スタイルのインテンシティが上がります。
+
+
+ -
+ すべての話者が平等に表現されているわけではありません。少ない表現の話者や
+ 分布外の入力はアーティファクトを生じさせる可能性があります。
+
+
+ -
+ Diffusionサンプラーが機能しているが、データセット内で特定の表現(例:極度の怒り)がない場合、
+ Diffusionサンプラーのパラメータを引き上げ、サンプラーにすべてを任せてください。ただし、それにより
+ 話者の類似性が低下する可能性があります。この問題を理想的に解決する方法は、ある話者から別の話者に
+ 感情を転送し新しいベクトルを作成することですが、ここではできません。
+
+
+ -
+ 音声ベースのインフレンスには、トレーニングデータの一部としてMoe-speechデータセットの一部を含む
+ litaginの素晴らしいデータセットを使用できます。
+
+
+ -
+ たまには音素の調整が必要になる場合もあります。バックエンドではcutletを使っているのですが、
+ いくつかのOODマッピングがcutletと相性が良くないみたいです。
+
+
+
+"""
+with gr.Blocks() as read_me:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown(notes)
+
+with gr.Blocks() as read_me_jp:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown(notes_jp)
+
+
+custom_css = """
+.tab-label {
+ color: #FFD700 !important;
+}
+"""
+
+
+
+
+with gr.Blocks(title="Tsukasa 司", css=custom_css + "footer{display:none !important}", theme="Respair/Shiki@1.2.2") as demo:
+ # gr.DuplicateButton("Duplicate Space")
+ gr.Markdown(INTROTXT)
+
+
+ gr.TabbedInterface([longform, audio_inf, prompt_inference, read_me, read_me_jp],
+ ['Kotodama Text Inference', 'Voice-guided Inference','Prompt-guided Inference [Highly Experimental - not optimized]', 'Read Me! [English]', 'Read Me! [日本語]'])
+
+if __name__ == "__main__":
+ demo.queue(api_open=False, max_size=15).launch(show_api=False, share=True)
diff --git a/app_tsumugi.py b/app_tsumugi.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec95d7de6b3dfdb264c03f7e238cacd5b03dffb8
--- /dev/null
+++ b/app_tsumugi.py
@@ -0,0 +1,645 @@
+INTROTXT = """#
+Repo -> [Hugging Face - 🤗](https://huggingface.co/Respair/Project_Kanade_SpeechModel)
+This space uses Tsukasa (24khz).
+**Check the Read me tabs down below.**
+Enjoy!
+"""
+import gradio as gr
+import random
+import importable
+import torch
+import os
+from Utils.phonemize.mixed_phon import smart_phonemize
+import numpy as np
+import pickle
+import re
+
+def is_japanese(text):
+ if not text: # Handle empty string
+ return False
+
+ # Define ranges for Japanese characters
+ japanese_ranges = [
+ (0x3040, 0x309F), # Hiragana
+ (0x30A0, 0x30FF), # Katakana
+ (0x4E00, 0x9FFF), # Kanji
+ (0x3000, 0x303F), # Japanese punctuation and symbols
+ (0xFF00, 0xFFEF), # Full-width characters
+ ]
+
+ # Define range for Latin alphabets
+ latin_alphabet_ranges = [
+ (0x0041, 0x005A), # Uppercase Latin
+ (0x0061, 0x007A), # Lowercase Latin
+ ]
+
+ # Define symbols to skip
+ symbols_to_skip = {'\'', '*', '!', '?', ',', '.', ':', ';', '-', '_', '(', ')', '[', ']', '{', '}', '"'}
+
+ for char in text:
+ if char.isspace() or char in symbols_to_skip: # Skip spaces and specified symbols
+ continue
+
+ char_code = ord(char)
+
+ # Check if the character is a Latin alphabet
+ is_latin_char = False
+ for start, end in latin_alphabet_ranges:
+ if start <= char_code <= end:
+ is_latin_char = True
+ break
+
+ if is_latin_char:
+ return False # Return False if a Latin alphabet character is found
+
+ # Check if the character is a Japanese character
+ is_japanese_char = False
+ for start, end in japanese_ranges:
+ if start <= char_code <= end:
+ is_japanese_char = True
+ break
+
+ if not is_japanese_char:
+ return False
+
+ return True
+
+
+
+voices = {}
+example_texts = {}
+prompts = []
+inputs = []
+
+
+theme = gr.themes.Base(
+ font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+)
+
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+
+voicelist = [v for v in os.listdir("/home/ubuntu/Kanade_Project/gradio/Tsukasa_Speech/reference_sample_wavs")]
+
+
+
+for v in voicelist:
+ voices[v] = f'reference_sample_wavs/{v}'
+
+
+with open(f'Inference/random_texts.txt', 'r') as r:
+ random_texts = [line.strip() for line in r]
+
+ example_texts = {f"{text[:30]}...": text for text in random_texts}
+
+def update_text_input(preview):
+
+ return example_texts[preview]
+
+def get_random_text():
+ return random.choice(random_texts)
+
+
+
+with open('Inference/prompt.txt', 'r') as p:
+ prompts = [line.strip() for line in p]
+
+with open('Inference/input_for_prompt.txt', 'r') as i:
+ inputs = [line.strip() for line in i]
+
+
+last_idx = None
+
+def get_random_prompt_pair():
+ global last_idx
+ max_idx = min(len(prompts), len(inputs)) - 1
+
+
+ random_idx = random.randint(0, max_idx)
+ while random_idx == last_idx:
+ random_idx = random.randint(0, max_idx)
+
+ last_idx = random_idx
+ return inputs[random_idx], prompts[random_idx]
+
+def Synthesize_Audio(text, voice, voice2, vcsteps, embscale, alpha, beta, ros, progress=gr.Progress()):
+
+
+ text = smart_phonemize(text)
+
+
+ if voice2 is not None:
+ voice2 = {"path": voice2, "meta": {"_type": "gradio.FileData"}}
+ print(voice2)
+ voice_style = importable.compute_style_through_clip(voice2['path'])
+
+ else:
+ voice_style = importable.compute_style_through_clip(voices[voice])
+
+ wav = importable.inference(
+ text,
+ voice_style,
+ alpha=alpha,
+ beta=beta,
+ diffusion_steps=vcsteps,
+ embedding_scale=embscale,
+ rate_of_speech=ros
+ )
+
+ return (24000, wav)
+
+
+def LongformSynth_Text(text, s_prev=None, Kotodama=None, alpha=.0, beta=0, t=.8, diffusion_steps=5, embedding_scale=1, rate_of_speech=1.):
+
+ japanese = text
+
+ # raw_jpn = japanese[japanese.find(":") + 2:]
+ # speaker = japanese[:japanese.find(":") + 2]
+
+
+ if ":" in japanese[:10]:
+ raw_jpn = japanese[japanese.find(":") + 2:]
+ speaker = japanese[:japanese.find(":") + 2]
+ else:
+ raw_jpn = japanese
+ speaker = ""
+
+ sentences = importable.sent_tokenizer.tokenize(raw_jpn)
+ sentences = importable.merging_sentences(sentences)
+
+
+ # if is_japanese(raw_jpn):
+ # kotodama_prompt = kotodama_prompt
+
+
+ # else:
+ # kotodama_prompt = speaker + importable.p2g(smart_phonemize(raw_jpn))
+ # print('kimia activated! the converted text is: ', kotodama_prompt)
+
+
+
+ silence = 24000 * 0.5 # 500 ms of silence between outputs for a more natural transition
+ # sentences = sent_tokenize(text)
+ print(sentences)
+ wavs = []
+ s_prev = None
+ for text in sentences:
+
+ text_input = smart_phonemize(text)
+ print('phonemes -> ', text_input)
+
+ if is_japanese(text):
+ kotodama_prompt = text
+
+
+ else:
+ kotodama_prompt = importable.p2g(smart_phonemize(text))
+ kotodama_prompt = re.sub(r'\s+', ' ', kotodama_prompt).strip()
+ print('kimia activated! the converted text is:\n ', kotodama_prompt)
+
+
+
+ Kotodama = importable.Kotodama_Sampler(importable.model, text=speaker + kotodama_prompt, device=importable.device)
+
+ wav, s_prev = importable.Longform(text_input,
+ s_prev,
+ Kotodama,
+ alpha = alpha,
+ beta = beta,
+ t = t,
+ diffusion_steps=diffusion_steps, embedding_scale=embedding_scale, rate_of_speech=rate_of_speech)
+ wavs.append(wav)
+ wavs.append(np.zeros(int(silence)))
+
+ print('Synthesized: ')
+ return (24000, np.concatenate(wavs))
+
+
+
+def Inference_Synth_Prompt(text, description, Kotodama, alpha, beta, diffusion_steps, embedding_scale, rate_of_speech , progress=gr.Progress()):
+
+ if is_japanese(text):
+ text = text
+
+
+ else:
+ text = importable.p2g(smart_phonemize(text))
+
+ print('kimia activated! the converted text is: ', text)
+
+
+ prompt = f"""{description} \n text: {text}"""
+
+ print('prompt ->: ', prompt)
+
+ text = smart_phonemize(text)
+
+ print('phonemes ->: ', text)
+
+ Kotodama = importable.Kotodama_Prompter(importable.model, text=prompt, device=importable.device)
+
+ wav = importable.inference(text,
+ Kotodama,
+ alpha = alpha,
+ beta = beta,
+ diffusion_steps=diffusion_steps, embedding_scale=embedding_scale, rate_of_speech=rate_of_speech)
+
+ wav = importable.trim_long_silences(wav)
+
+
+ print('Synthesized: ')
+ return (24000, wav)
+
+with gr.Blocks() as audio_inf:
+ with gr.Row():
+ with gr.Column(scale=1):
+ inp = gr.Textbox(label="Text", info="Enter the text", value="きみの存在は、私の心の中で燃える小さな光のよう。きみがいない時、世界は白黒の写真みたいに寂しくて、何も輝いてない。きみの笑顔だけが、私の灰色の日々に色を塗ってくれる。離れてる時間は、めちゃくちゃ長く感じられて、きみへの想いは風船みたいにどんどん膨らんでいく。きみなしの世界なんて、想像できないよ。", interactive=True, scale=5)
+ voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value=voicelist[5], interactive=True)
+ voice_2 = gr.Audio(label="Upload your own Audio", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+
+ with gr.Accordion("Advanced Parameters", open=False):
+
+ alpha = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1, label="Alpha", info="a Diffusion sampler parameter handling the timbre, higher means less affected by the reference | 0 = diffusion is disabled", interactive=True)
+ beta = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1, label="Beta", info="a Diffusion sampler parameter, higher means less affected by the reference | 0 = diffusion is disabled", interactive=True)
+ multispeakersteps = gr.Slider(minimum=3, maximum=15, value=5, step=1, label="Diffusion Steps", interactive=True)
+ embscale = gr.Slider(minimum=1, maximum=5, value=1, step=0.1, label="Intensity", info="will impact the expressiveness, if you raise it too much it'll break.", interactive=True)
+ rate_of_speech = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1, label="Rate of Speech", info="Higher -> Faster", interactive=True)
+
+ with gr.Column(scale=1):
+ btn = gr.Button("Synthesize", variant="primary")
+ audio = gr.Audio(interactive=False, label="Synthesized Audio", waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+ btn.click(Synthesize_Audio, inputs=[inp, voice, voice_2, multispeakersteps, embscale, alpha, beta, rate_of_speech], outputs=[audio], concurrency_limit=4)
+
+# Kotodama Text sampler Synthesis Block
+with gr.Blocks() as longform:
+ with gr.Row():
+ with gr.Column(scale=1):
+ inp_longform = gr.Textbox(
+ label="Text",
+ info="Enter the text [Speaker: Text] | Also works without any name.",
+ value=list(example_texts.values())[0],
+ interactive=True,
+ scale=5
+ )
+
+ with gr.Row():
+ example_dropdown = gr.Dropdown(
+ choices=list(example_texts.keys()),
+ label="Example Texts [pick one!]",
+ value=list(example_texts.keys())[0],
+ interactive=True
+ )
+
+ example_dropdown.change(
+ fn=update_text_input,
+ inputs=[example_dropdown],
+ outputs=[inp_longform]
+ )
+
+ with gr.Accordion("Advanced Parameters", open=False):
+
+ alpha_longform = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Alpha",
+ info="a Diffusion parameter handling the timbre, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ beta_longform = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Beta",
+ info="a Diffusion parameter, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ diffusion_steps_longform = gr.Slider(minimum=3, maximum=15, value=10, step=1,
+ label="Diffusion Steps",
+ interactive=True)
+ embedding_scale_longform = gr.Slider(minimum=1, maximum=5, value=1.25, step=0.1,
+ label="Intensity",
+ info="a Diffusion parameter, it will impact the expressiveness, if you raise it too much it'll break.",
+ interactive=True)
+
+ rate_of_speech_longform = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1,
+ label="Rate of Speech",
+ info="Higher = Faster",
+ interactive=True)
+
+ with gr.Column(scale=1):
+ btn_longform = gr.Button("Synthesize", variant="primary")
+ audio_longform = gr.Audio(interactive=False,
+ label="Synthesized Audio",
+ waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+
+ btn_longform.click(LongformSynth_Text,
+ inputs=[inp_longform,
+ gr.State(None), # s_prev
+ gr.State(None), # Kotodama
+ alpha_longform,
+ beta_longform,
+ gr.State(.8), # t parameter
+ diffusion_steps_longform,
+ embedding_scale_longform,
+ rate_of_speech_longform],
+ outputs=[audio_longform],
+ concurrency_limit=4)
+
+# Kotodama prompt sampler Inference Block
+with gr.Blocks() as prompt_inference:
+ with gr.Row():
+ with gr.Column(scale=1):
+ text_prompt = gr.Textbox(
+ label="Text",
+ info="Enter the text to synthesize. This text will also be fed to the encoder. Make sure to see the Read Me for more details!",
+ value=inputs[0],
+ interactive=True,
+ scale=5
+ )
+ description_prompt = gr.Textbox(
+ label="Description",
+ info="Enter a highly detailed, descriptive prompt that matches the vibe of your text to guide the synthesis.",
+ value=prompts[0],
+ interactive=True,
+ scale=7
+ )
+
+ with gr.Row():
+ random_btn = gr.Button('Random Example', variant='secondary')
+
+ with gr.Accordion("Advanced Parameters", open=True):
+ embedding_scale_prompt = gr.Slider(minimum=1, maximum=5, value=1, step=0.25,
+ label="Intensity",
+ info="it will impact the expressiveness, if you raise it too much it'll break.",
+ interactive=True)
+ alpha_prompt = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Alpha",
+ info="a Diffusion sampler parameter handling the timbre, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ beta_prompt = gr.Slider(minimum=0, maximum=1, value=0.0, step=0.1,
+ label="Beta",
+ info="a Diffusion sampler parameter, higher means less affected by the reference | 0 = diffusion is disabled",
+ interactive=True)
+ diffusion_steps_prompt = gr.Slider(minimum=3, maximum=15, value=10, step=1,
+ label="Diffusion Steps",
+ interactive=True)
+ rate_of_speech_prompt = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1,
+ label="Rate of Speech",
+ info="Higher = Faster",
+ interactive=True)
+ with gr.Column(scale=1):
+ btn_prompt = gr.Button("Synthesize with Prompt", variant="primary")
+ audio_prompt = gr.Audio(interactive=False,
+ label="Prompt-based Synthesized Audio",
+ waveform_options={'waveform_color': '#a3ffc3', 'waveform_progress_color': '#e972ab'})
+
+
+ random_btn.click(
+ fn=get_random_prompt_pair,
+ inputs=[],
+ outputs=[text_prompt, description_prompt]
+ )
+
+ btn_prompt.click(Inference_Synth_Prompt,
+ inputs=[text_prompt,
+ description_prompt,
+ gr.State(None),
+ alpha_prompt,
+ beta_prompt,
+ diffusion_steps_prompt,
+ embedding_scale_prompt,
+ rate_of_speech_prompt],
+ outputs=[audio_prompt],
+ concurrency_limit=4)
+
+notes = """
+Notes
+
+
+This work is somewhat different from your typical speech model. It offers a high degree of control
+over the generation process, which means it's easy to inadvertently produce unimpressive outputs.
+
+
+
+Kotodama and the Diffusion sampler can significantly help guide the generation towards
+something that aligns with your input, but they aren't foolproof.
+
+
+
+The model's peak performance is achieved when the Diffusion sampler and Kotodama work seamlessly together.
+However, we won't see that level of performance here because this checkpoint is somewhat undertrained
+due to my time and resource constraints. (Tsumugi should be better in this regard,
+albeit if the diffusion works at all on your hardware.)
+Hopefully, you can further fine-tune this model (or train from scratch) to achieve even better results!
+
+
+
+The prompt encoder is also highly experimental and should be treated as a proof of concept. Due to the
+overwhelming ratio of female to male speakers and the wide variation in both speakers and their expressions,
+the prompt encoder may occasionally produce subpar or contradicting outputs. For example, high expressiveness alongside
+high pitch has been associated with females speakers simply because I had orders of magnitude more of them in the dataset.
+
+
+
+________________________________________________________
+A useful note about the voice design and prompting:
\n
+The vibe of the dialogue impacts the generated voice since the Japanese dialogue
+and the prompts were jointly trained. This is a peculiar feature of the Japanese lanuage.
+For example if you use 俺 (ore)、僕(boku) or your input is overall masculine
+you may get a guy's voice, even if you describe it as female in the prompt.
\n
+The Japanese text that is fed to the prompt doesn't necessarily have to be
+the same as your input, but we can't do it in this demo
+to not make the page too convoluted. In a real world scenario, you can just use a
+prompt with a suitable Japanese text to guide the model, get the style
+then move on to apply it to whatever dialogue you wish your model to speak.
+
+
+
+________________________________________________________
+
+The pitch information in my data was accurately calculated, but it only works in comparison to the other speakers
+so you may find a deep pitch may not be exactly too deep; although it actually is
+when you compare it to others within the same data, also some of the gender labels
+are inaccurate since we used a model to annotate them.
\n
+The main goal of this inference method is to demonstrate that style can be mapped to description's embeddings
+yielding reasonably good results.
+
+
+
+Overall, I'm confident that with a bit of experimentation, you can achieve reasonbaly good results.
+The model should work well out of the box 90% of the time without the need for extensive tweaking.
+However, here are some tips in case you encounter issues:
+
+
+Tips:
+
+
+ -
+ Ensure that your input closely matches your reference (audio or text prompt) in terms of tone,
+ non-verbal cues, duration, etc.
+
+
+ -
+ If your audio is too long but the input is too short, the speech rate will be slow, and vice versa.
+
+
+ -
+ Experiment with the alpha, beta, and Intensity parameters. The Diffusion
+ sampler is non-deterministic, so regenerate a few times if you're not satisfied with the output.
+
+
+ -
+ The speaker's share and expressive distribution in the dataset significantly impact the quality;
+ you won't necessarily get perfect results with all speakers.
+
+
+ -
+ Punctuation is very important, for example if you add «!» mark it will raise the voice or make it more intense.
+
+
+ -
+ Not all speakers are equal. Less represented speakers or out-of-distribution inputs may result
+ in artifacts.
+
+
+ -
+ If the Diffusion sampler works but the speaker didn't have a certain expression (e.g., extreme anger)
+ in the dataset, try raising the diffusion sampler's parameters and let it handle everything. Though
+ it may result in less speaker similarity, the ideal way to handle this is to cook new vectors by
+ transferring an emotion from one speaker to another. But you can't do that in this space.
+
+
+ -
+ For voice-based inference, you can use litagin's awesome Moe-speech dataset,
+ as part of the training data includes a portion of that.
+
+
+ -
+ you may also want to tweak the phonemes if you're going for something wild.
+ i have used cutlet in the backend, but that doesn't seem to like some of my mappings.
+
+
+
+
+"""
+
+
+notes_jp = """
+メモ
+
+
+この作業は、典型的なスピーチモデルとは少し異なります。生成プロセスに対して高い制御を提供するため、意図せずに
+比較的にクオリティーの低い出力を生成してしまうことが容易です。
+
+
+
+KotodamaとDiffusionサンプラーは、入力に沿ったものを生成するための大きな助けとなりますが、
+万全というわけではありません。
+
+
+
+モデルの最高性能は、DiffusionサンプラーとKotodamaがシームレスに連携することで達成されます。しかし、
+このチェックポイントは時間とリソースの制約からややTrain不足であるため、そのレベルの性能はここでは見られません。
+(この件について、「紬」のチェックポイントの方がいいかもしれません。でもまぁ、みなさんのハードに互換性があればね。)
+おそらく、このモデルをさらにFinetuningする(または最初からTrainする)ことで、より良い結果が得られるでしょう。
+
+
+_____________________________________________
\n
+音声デザインとプロンプトに関する有用なメモ:
+ダイアログの雰囲気は、日本語のダイアログとプロンプトが共同でTrainされたため、生成される音声に影響を与えます。
+これは日本語の特徴的な機能です。例えば、「俺」や「僕」を使用したり、全体的に男性らしい入力をすると、
+プロンプトで女性と記述していても、男性の声が得られる可能性があります。
+プロンプトに入力される日本語のテキストは、必ずしも入力内容と同じである必要はありませんが、
+このデモではページが複雑になりすぎないようにそれを行うことはできません。
+実際のシナリオでは、適切な日本語のテキストを含むプロンプトを使用してモデルを導き、
+スタイルを取得した後、それを希望するダイアログに適用することができます。
+
+_____________________________________________
\n
+
+
+プロンプトエンコーダも非常に実験的であり、概念実証として扱うべきです。女性話者対男性話者の比率が圧倒的で、
+また話者とその表現に大きなバリエーションがあるため、エンコーダは質の低い出力を生成する可能性があります。
+例えば、高い表現力は、データセットに多く含まれていた女性話者と関連付けられています。
+それに、データのピッチ情報は正確に計算されましたが、それは他のスピーカーとの比較でしか機能しません...
+だから、深いピッチが必ずしも深すぎるわけではないことに気づくかもしれません。
+ただし、実際には、同じデータ内の他の人と比較すると、深すぎます。このインフレンスの主な目的は、
+スタイルベクトルを記述にマッピングし、合理的に良い結果を得ることにあります。
+
+
+
+全体として、少しの実験でほぼ望む結果を達成できると自信を持っています。90%のケースで、大幅な調整を必要とせず、
+そのままでうまく動作するはずです。しかし、問題が発生した場合のためにいくつかのヒントがあります:
+
+
+ヒント:
+
+
+ -
+ 入力がリファレンス(音声またはテキストプロンプト)とトーン、非言語的な手がかり、
+ 長さなどで密接に一致していることを確認してください。
+
+
+ -
+ 音声が長すぎるが入力が短すぎる場合、話速が遅くなります。その逆もまた同様です。
+
+
+ -
+ アルファ、ベータ、および埋め込みスケールのパラメータを試行錯誤してください。Diffusionサンプラーは
+ 非決定的なので、満足のいく出力が得られない場合は何度か再生成してください。
+
+
+ -
+ データセット内の話者の分布と表現力の分布は品質に大きく影響します。
+ すべての話者で必ずしも完璧な結果が得られるわけではありません。
+
+
+ -
+ 句読点は重要です。たとえな、「!」を使えば、スタイルのインテンシティが上がります。
+
+
+ -
+ すべての話者が平等に表現されているわけではありません。少ない表現の話者や
+ 分布外の入力はアーティファクトを生じさせる可能性があります。
+
+
+ -
+ Diffusionサンプラーが機能しているが、データセット内で特定の表現(例:極度の怒り)がない場合、
+ Diffusionサンプラーのパラメータを引き上げ、サンプラーにすべてを任せてください。ただし、それにより
+ 話者の類似性が低下する可能性があります。この問題を理想的に解決する方法は、ある話者から別の話者に
+ 感情を転送し新しいベクトルを作成することですが、ここではできません。
+
+
+ -
+ 音声ベースのインフレンスには、トレーニングデータの一部としてMoe-speechデータセットの一部を含む
+ litaginの素晴らしいデータセットを使用できます。
+
+
+ -
+ たまには音素の調整が必要になる場合もあります。バックエンドではcutletを使っているのですが、
+ いくつかのOODマッピングがcutletと相性が良くないみたいです。
+
+
+
+"""
+with gr.Blocks() as read_me:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown(notes)
+
+with gr.Blocks() as read_me_jp:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown(notes_jp)
+
+
+custom_css = """
+.tab-label {
+ color: #FFD700 !important;
+}
+"""
+
+
+
+
+with gr.Blocks(title="Tsukasa 司", css=custom_css + "footer{display:none !important}", theme="Respair/Shiki@1.2.1") as demo:
+ # gr.DuplicateButton("Duplicate Space")
+ gr.Markdown(INTROTXT)
+
+
+ gr.TabbedInterface([longform, audio_inf, prompt_inference, read_me, read_me_jp],
+ ['Kotodama Text Inference', 'Voice-guided Inference','Prompt-guided Inference [Highly Experimental - not optimized]', 'Read Me! [English]', 'Read Me! [日本語]'])
+
+if __name__ == "__main__":
+ demo.queue(api_open=False, max_size=15).launch(show_api=False, share=True)
diff --git a/importable.py b/importable.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa3ca5a27a45b4e4793101e1dd95b8c9c7c7e171
--- /dev/null
+++ b/importable.py
@@ -0,0 +1,423 @@
+print("NLTK")
+import nltk
+nltk.download('punkt')
+print("SCIPY")
+from scipy.io.wavfile import write
+print("TORCH STUFF")
+import torch
+print("START")
+torch.manual_seed(0)
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+
+
+# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+# import torch
+# print(torch.cuda.device_count())
+import IPython.display as ipd
+import os
+os.environ['CUDA_HOME'] = '/home/ubuntu/miniconda3/envs/respair/lib/python3.11/site-packages/torch/lib/include/cuda'
+import torch
+torch.manual_seed(0)
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+
+import random
+random.seed(0)
+
+import numpy as np
+np.random.seed(0)
+
+# load packages
+from text_utils import TextCleaner
+textclenaer = TextCleaner()
+
+
+def length_to_mask(lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
+
+
+
+import time
+import random
+import yaml
+from munch import Munch
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+from nltk.tokenize import word_tokenize
+
+from models import *
+from Modules.KotoDama_sampler import tokenizer_koto_prompt, tokenizer_koto_text
+from utils import *
+
+import nltk
+nltk.download('punkt_tab')
+
+from nltk.tokenize import sent_tokenize
+
+from konoha import SentenceTokenizer
+
+
+sent_tokenizer = SentenceTokenizer()
+
+# %matplotlib inline
+to_mel = torchaudio.transforms.MelSpectrogram(
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
+mean, std = -4, 4
+
+
+def preprocess(wave):
+ wave_tensor = torch.from_numpy(wave).float()
+ mel_tensor = to_mel(wave_tensor)
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
+ return mel_tensor
+
+def compute_style_through_clip(path):
+ wave, sr = librosa.load(path, sr=24000)
+ audio, index = librosa.effects.trim(wave, top_db=30)
+ if sr != 24000:
+ audio = librosa.resample(audio, sr, 24000)
+ mel_tensor = preprocess(audio).to(device)
+
+ with torch.no_grad():
+ ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
+ ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
+
+ return torch.cat([ref_s, ref_p], dim=1)
+
+
+def Kotodama_Prompter(model, text, device):
+
+ with torch.no_grad():
+ style = model.KotoDama_Prompt(**tokenizer_koto_prompt(text, return_tensors="pt").to(device))['logits']
+ return style
+
+def Kotodama_Sampler(model, text, device):
+
+ with torch.no_grad():
+ style = model.KotoDama_Text(**tokenizer_koto_text(text, return_tensors="pt").to(device))['logits']
+ return style
+
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+config = yaml.safe_load(open("Configs/config_kanade.yml"))
+
+# load pretrained ASR model
+ASR_config = config.get('ASR_config', False)
+ASR_path = config.get('ASR_path', False)
+text_aligner = load_ASR_models(ASR_path, ASR_config)
+
+
+KotoDama_Prompter = load_KotoDama_Prompter(path="Utils/KTD/prompt_enc/checkpoint-73285")
+KotoDama_TextSampler = load_KotoDama_TextSampler(path="Utils/KTD/text_enc/checkpoint-22680")
+
+# load pretrained F0 model
+F0_path = config.get('F0_path', False)
+pitch_extractor = load_F0_models(F0_path)
+
+# load BERT model
+from Utils.PLBERT.util import load_plbert
+BERT_path = config.get('PLBERT_dir', False)
+plbert = load_plbert(BERT_path)
+
+model_params = recursive_munch(config['model_params'])
+model = build_model(model_params, text_aligner, pitch_extractor, plbert, KotoDama_Prompter, KotoDama_TextSampler)
+_ = [model[key].eval() for key in model]
+_ = [model[key].to(device) for key in model]
+
+params_whole = torch.load("Models/Style_Tsukasa_v02/Top_ckpt_24khz.pth", map_location='cpu')
+params = params_whole['net']
+
+
+for key in model:
+ if key in params:
+ print('%s loaded' % key)
+ try:
+ model[key].load_state_dict(params[key])
+ except:
+ from collections import OrderedDict
+ state_dict = params[key]
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = k[7:] # remove `module.`
+ new_state_dict[name] = v
+ # load params
+ model[key].load_state_dict(new_state_dict, strict=False)
+# except:
+# _load(params[key], model[key])
+
+
+_ = [model[key].eval() for key in model]
+
+
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+diffusion_sampler = DiffusionSampler(
+ model.diffusion.diffusion,
+ sampler=ADPM2Sampler(),
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
+ clamp=False
+)
+
+def inference(text=None, ref_s=None, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, rate_of_speech=1.):
+
+ tokens = textclenaer(text)
+ tokens.insert(0, 0)
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
+
+ with torch.no_grad():
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+
+ text_mask = length_to_mask(input_lengths).to(device)
+
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+
+
+ s_pred = diffusion_sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=embedding_scale,
+ features=ref_s, # reference from the same speaker as the embedding
+ num_steps=diffusion_steps).squeeze(1)
+
+
+ s = s_pred[:, 128:]
+ ref = s_pred[:, :128]
+
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
+
+ d = model.predictor.text_encoder(d_en,
+ s, input_lengths, text_mask)
+
+
+
+ x = model.predictor.lstm(d)
+ x_mod = model.predictor.prepare_projection(x)
+ duration = model.predictor.duration_proj(x_mod)
+
+
+ duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
+
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+
+
+
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
+
+ c_frame = 0
+ for i in range(pred_aln_trg.size(0)):
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+ c_frame += int(pred_dur[i].data)
+
+ # encode prosody
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
+
+
+
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
+
+
+ out = model.decoder(asr,
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+
+ return out.squeeze().cpu().numpy()[..., :-50]
+
+
+def Longform(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, rate_of_speech=1.0):
+
+
+ tokens = textclenaer(text)
+ tokens.insert(0, 0)
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
+
+ with torch.no_grad():
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+ text_mask = length_to_mask(input_lengths).to(device)
+
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+ s_pred = diffusion_sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
+ embedding=bert_dur,
+ embedding_scale=embedding_scale,
+ features=ref_s,
+ num_steps=diffusion_steps).squeeze(1)
+
+ if s_prev is not None:
+ # convex combination of previous and current style
+ s_pred = t * s_prev + (1 - t) * s_pred
+
+ s = s_pred[:, 128:]
+ ref = s_pred[:, :128]
+
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
+
+ s_pred = torch.cat([ref, s], dim=-1)
+
+ d = model.predictor.text_encoder(d_en,
+ s, input_lengths, text_mask)
+
+ x = model.predictor.lstm(d)
+ x_mod = model.predictor.prepare_projection(x) # 640 -> 512
+ duration = model.predictor.duration_proj(x_mod)
+
+ duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+
+
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
+ c_frame = 0
+ for i in range(pred_aln_trg.size(0)):
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+ c_frame += int(pred_dur[i].data)
+
+ # encode prosody
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
+
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
+
+ out = model.decoder(asr,
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+
+
+ return out.squeeze().cpu().numpy()[..., :-100], s_pred
+
+
+
+def trim_long_silences(wav_data, sample_rate=24000, silence_threshold=0.01, min_silence_duration=0.8):
+
+
+ min_silence_samples = int(min_silence_duration * sample_rate)
+
+
+ envelope = np.abs(wav_data)
+
+
+ silence_mask = envelope < silence_threshold
+
+
+ silence_changes = np.diff(silence_mask.astype(int))
+ silence_starts = np.where(silence_changes == 1)[0] + 1
+ silence_ends = np.where(silence_changes == -1)[0] + 1
+
+
+ if silence_mask[0]:
+ silence_starts = np.concatenate(([0], silence_starts))
+ if silence_mask[-1]:
+ silence_ends = np.concatenate((silence_ends, [len(wav_data)]))
+
+
+ if len(silence_starts) == 0 or len(silence_ends) == 0:
+ return wav_data
+
+ processed_segments = []
+ last_end = 0
+
+ for start, end in zip(silence_starts, silence_ends):
+
+ processed_segments.append(wav_data[last_end:start])
+
+
+ silence_duration = end - start
+
+ if silence_duration > min_silence_samples:
+
+ silence_segment = np.zeros(min_silence_samples)
+
+ fade_samples = min(1000, min_silence_samples // 4)
+ fade_in = np.linspace(0, 1, fade_samples)
+ fade_out = np.linspace(1, 0, fade_samples)
+ silence_segment[:fade_samples] *= fade_in
+ silence_segment[-fade_samples:] *= fade_out
+ processed_segments.append(silence_segment)
+ else:
+
+ processed_segments.append(wav_data[start:end])
+
+ last_end = end
+
+
+ if last_end < len(wav_data):
+ processed_segments.append(wav_data[last_end:])
+
+
+ return np.concatenate(processed_segments)
+
+
+def merge_short_elements(lst):
+ i = 0
+ while i < len(lst):
+ if i > 0 and len(lst[i]) < 10:
+ lst[i-1] += ' ' + lst[i]
+ lst.pop(i)
+ else:
+ i += 1
+ return lst
+
+
+def merge_three(text_list, maxim=2):
+
+ merged_list = []
+ for i in range(0, len(text_list), maxim):
+ merged_text = ' '.join(text_list[i:i+maxim])
+ merged_list.append(merged_text)
+ return merged_list
+
+
+def merging_sentences(lst):
+ return merge_three(merge_short_elements(lst))
+
+
+import os
+
+
+from openai import OpenAI
+
+
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+ api_key=openai_api_key,
+ base_url=openai_api_base,
+)
+
+model_name = "Respair/Japanese_Phoneme_to_Grapheme_LLM"
+
+
+def p2g(param):
+
+ chat_response = client.chat.completions.create(
+
+ model=model_name,
+ max_tokens=512,
+ temperature=0.1,
+
+
+ messages=[
+
+ {"role": "user", "content": f"convert this pronunciation back to normal japanese if you see one, otherwise copy the same thing: {param}"}]
+ )
+
+ result = chat_response.choices[0].message.content
+ # if " " in result:
+ # result = result.replace(" "," ")
+
+ return result.lstrip()
\ No newline at end of file
diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..135807afaf698162e3d5bb577072246f5835ae80
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,303 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+from transformers import AutoModel
+
+
+class SpectralConvergengeLoss(torch.nn.Module):
+ """Spectral convergence loss module."""
+
+ def __init__(self):
+ """Initilize spectral convergence loss module."""
+ super(SpectralConvergengeLoss, self).__init__()
+
+ def forward(self, x_mag, y_mag):
+ """Calculate forward propagation.
+ Args:
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ """
+ return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
+
+
+class STFTLoss(torch.nn.Module):
+ """STFT loss module."""
+
+ def __init__(
+ self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
+ ):
+ """Initialize STFT loss module."""
+ super(STFTLoss, self).__init__()
+ self.fft_size = fft_size
+ self.shift_size = shift_size
+ self.win_length = win_length
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
+ sample_rate=24000,
+ n_fft=fft_size,
+ win_length=win_length,
+ hop_length=shift_size,
+ window_fn=window,
+ )
+
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ Tensor: Log STFT magnitude loss value.
+ """
+ x_mag = self.to_mel(x)
+ mean, std = -4, 4
+ x_mag = (torch.log(1e-5 + x_mag) - mean) / std
+
+ y_mag = self.to_mel(y)
+ mean, std = -4, 4
+ y_mag = (torch.log(1e-5 + y_mag) - mean) / std
+
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+ return sc_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+ """Multi resolution STFT loss module."""
+
+ def __init__(
+ self,
+ fft_sizes=[1024, 2048, 512],
+ hop_sizes=[120, 240, 50],
+ win_lengths=[600, 1200, 240],
+ window=torch.hann_window,
+ ):
+ """Initialize Multi resolution STFT loss module.
+ Args:
+ fft_sizes (list): List of FFT sizes.
+ hop_sizes (list): List of hop sizes.
+ win_lengths (list): List of window lengths.
+ window (str): Window function type.
+ """
+ super(MultiResolutionSTFTLoss, self).__init__()
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+ self.stft_losses = torch.nn.ModuleList()
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Multi resolution spectral convergence loss value.
+ Tensor: Multi resolution log STFT magnitude loss value.
+ """
+ sc_loss = 0.0
+ for f in self.stft_losses:
+ sc_l = f(x, y)
+ sc_loss += sc_l
+ sc_loss /= len(self.stft_losses)
+
+ return sc_loss
+
+
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1 - dr) ** 2)
+ g_loss = torch.mean(dg**2)
+ loss += r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean((1 - dg) ** 2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+
+""" https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
+
+
+def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ tau = 0.04
+ m_DG = torch.median((dr - dg))
+ L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
+ loss += tau - F.relu(tau - L_rel)
+ return loss
+
+
+def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
+ tau = 0.04
+ m_DG = torch.median((dr - dg))
+ L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
+ loss += tau - F.relu(tau - L_rel)
+ return loss
+
+
+class GeneratorLoss(torch.nn.Module):
+ def __init__(self, mpd, msd):
+ super(GeneratorLoss, self).__init__()
+ self.mpd = mpd
+ self.msd = msd
+
+ def forward(self, y, y_hat):
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
+
+ loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(
+ y_ds_hat_r, y_ds_hat_g
+ )
+
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
+
+ return loss_gen_all.mean()
+
+
+class DiscriminatorLoss(torch.nn.Module):
+ def __init__(self, mpd, msd):
+ super(DiscriminatorLoss, self).__init__()
+ self.mpd = mpd
+ self.msd = msd
+
+ def forward(self, y, y_hat):
+ # MPD
+ y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
+ y_df_hat_r, y_df_hat_g
+ )
+ # MSD
+ y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
+ y_ds_hat_r, y_ds_hat_g
+ )
+
+ loss_rel = discriminator_TPRLS_loss(
+ y_df_hat_r, y_df_hat_g
+ ) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
+
+ d_loss = loss_disc_s + loss_disc_f + loss_rel
+
+ return d_loss.mean()
+
+
+class WavLMLoss(torch.nn.Module):
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
+ super(WavLMLoss, self).__init__()
+ self.wavlm = AutoModel.from_pretrained(model)
+ self.wd = wd
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
+
+ def forward(self, wav, y_rec):
+ with torch.no_grad():
+ wav_16 = self.resample(wav)
+ wav_embeddings = self.wavlm(
+ input_values=wav_16, output_hidden_states=True
+ ).hidden_states
+ y_rec_16 = self.resample(y_rec)
+ y_rec_embeddings = self.wavlm(
+ input_values=y_rec_16.squeeze(), output_hidden_states=True
+ ).hidden_states
+
+ floss = 0
+ for er, eg in zip(wav_embeddings, y_rec_embeddings):
+ floss += torch.mean(torch.abs(er - eg))
+
+ return floss.mean()
+
+ def generator(self, y_rec):
+ y_rec_16 = self.resample(y_rec)
+ y_rec_embeddings = self.wavlm(
+ input_values=y_rec_16, output_hidden_states=True
+ ).hidden_states
+ y_rec_embeddings = (
+ torch.stack(y_rec_embeddings, dim=1)
+ .transpose(-1, -2)
+ .flatten(start_dim=1, end_dim=2)
+ )
+ y_df_hat_g = self.wd(y_rec_embeddings)
+ loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
+
+ return loss_gen
+
+ def discriminator(self, wav, y_rec):
+ with torch.no_grad():
+ wav_16 = self.resample(wav)
+ wav_embeddings = self.wavlm(
+ input_values=wav_16, output_hidden_states=True
+ ).hidden_states
+ y_rec_16 = self.resample(y_rec)
+ y_rec_embeddings = self.wavlm(
+ input_values=y_rec_16, output_hidden_states=True
+ ).hidden_states
+
+ y_embeddings = (
+ torch.stack(wav_embeddings, dim=1)
+ .transpose(-1, -2)
+ .flatten(start_dim=1, end_dim=2)
+ )
+ y_rec_embeddings = (
+ torch.stack(y_rec_embeddings, dim=1)
+ .transpose(-1, -2)
+ .flatten(start_dim=1, end_dim=2)
+ )
+
+ y_d_rs = self.wd(y_embeddings)
+ y_d_gs = self.wd(y_rec_embeddings)
+
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+ r_loss = torch.mean((1 - y_df_hat_r) ** 2)
+ g_loss = torch.mean((y_df_hat_g) ** 2)
+
+ loss_disc_f = r_loss + g_loss
+
+ return loss_disc_f.mean()
+
+ def discriminator_forward(self, wav):
+ with torch.no_grad():
+ wav_16 = self.resample(wav)
+ wav_embeddings = self.wavlm(
+ input_values=wav_16, output_hidden_states=True
+ ).hidden_states
+ y_embeddings = (
+ torch.stack(wav_embeddings, dim=1)
+ .transpose(-1, -2)
+ .flatten(start_dim=1, end_dim=2)
+ )
+
+ y_d_rs = self.wd(y_embeddings)
+
+ return y_d_rs
diff --git a/meldataset.py b/meldataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e845d90373d97627b261ba11ff2adb3736a24405
--- /dev/null
+++ b/meldataset.py
@@ -0,0 +1,256 @@
+#coding: utf-8
+import os
+import os.path as osp
+import time
+import random
+import numpy as np
+import random
+import soundfile as sf
+import librosa
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+from torch.utils.data import DataLoader
+
+import logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+import pandas as pd
+
+_pad = "$"
+_punctuation = ';:,.!?¡¿—…"«»“” '
+_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+
+# Export all symbols:
+symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+dicts = {}
+for i in range(len((symbols))):
+ dicts[symbols[i]] = i
+
+class TextCleaner:
+ def __init__(self, dummy=None):
+ self.word_index_dictionary = dicts
+ def __call__(self, text):
+ indexes = []
+ for char in text:
+ try:
+ indexes.append(self.word_index_dictionary[char])
+ except KeyError:
+ print(text)
+ return indexes
+
+np.random.seed(1)
+random.seed(1)
+SPECT_PARAMS = {
+ "n_fft": 2048,
+ "win_length": 1200,
+ "hop_length": 300
+}
+MEL_PARAMS = {
+ "n_mels": 80,
+}
+
+to_mel = torchaudio.transforms.MelSpectrogram(
+
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
+mean, std = -4, 4
+
+def preprocess(wave):
+ wave_tensor = torch.from_numpy(wave).float()
+ mel_tensor = to_mel(wave_tensor)
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
+ return mel_tensor
+
+class FilePathDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ data_list,
+ root_path,
+ sr=24000,
+ data_augmentation=False,
+ validation=False,
+ OOD_data="Data/OOD_texts.txt",
+ min_length=50,
+ ):
+
+ spect_params = SPECT_PARAMS
+ mel_params = MEL_PARAMS
+
+ _data_list = [l.strip().split('|') for l in data_list]
+ self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
+ self.text_cleaner = TextCleaner()
+ self.sr = sr
+
+ self.df = pd.DataFrame(self.data_list)
+
+ self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
+
+ self.mean, self.std = -4, 4
+ self.data_augmentation = data_augmentation and (not validation)
+ self.max_mel_length = 192
+
+ self.min_length = min_length
+ with open(OOD_data, 'r', encoding='utf-8') as f:
+ tl = f.readlines()
+ idx = 1 if '.wav' in tl[0].split('|')[0] else 0
+ self.ptexts = [t.split('|')[idx] for t in tl]
+
+ self.root_path = root_path
+
+ def __len__(self):
+ return len(self.data_list)
+
+ def __getitem__(self, idx):
+ data = self.data_list[idx]
+ path = data[0]
+
+ wave, text_tensor, speaker_id = self._load_tensor(data)
+
+ mel_tensor = preprocess(wave).squeeze()
+
+ acoustic_feature = mel_tensor.squeeze()
+ length_feature = acoustic_feature.size(1)
+ acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)]
+
+ # get reference sample
+ ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
+ ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
+
+ # get OOD text
+
+ ps = ""
+
+ while len(ps) < self.min_length:
+ rand_idx = np.random.randint(0, len(self.ptexts) - 1)
+ ps = self.ptexts[rand_idx]
+
+ text = self.text_cleaner(ps)
+ text.insert(0, 0)
+ text.append(0)
+
+ ref_text = torch.LongTensor(text)
+
+ return speaker_id, acoustic_feature, text_tensor, ref_text, ref_mel_tensor, ref_label, path, wave
+
+ def _load_tensor(self, data):
+ wave_path, text, speaker_id = data
+ speaker_id = int(speaker_id)
+ wave, sr = sf.read(osp.join(self.root_path, wave_path))
+ if wave.shape[-1] == 2:
+ wave = wave[:, 0].squeeze()
+ if sr != 24000:
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
+ print(wave_path, sr)
+
+ wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
+
+ text = self.text_cleaner(text)
+
+ text.insert(0, 0)
+ text.append(0)
+
+ text = torch.LongTensor(text)
+
+ return wave, text, speaker_id
+
+ def _load_data(self, data):
+ wave, text_tensor, speaker_id = self._load_tensor(data)
+ mel_tensor = preprocess(wave).squeeze()
+
+ mel_length = mel_tensor.size(1)
+ if mel_length > self.max_mel_length:
+ random_start = np.random.randint(0, mel_length - self.max_mel_length)
+ mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length]
+
+ return mel_tensor, speaker_id
+
+
+class Collater(object):
+ """
+ Args:
+ adaptive_batch_size (bool): if true, decrease batch size when long data comes.
+ """
+
+ def __init__(self, return_wave=False):
+ self.text_pad_index = 0
+ self.min_mel_length = 192
+ self.max_mel_length = 192
+ self.return_wave = return_wave
+
+
+ def __call__(self, batch):
+ # batch[0] = wave, mel, text, f0, speakerid
+ batch_size = len(batch)
+
+ # sort by mel length
+ lengths = [b[1].shape[1] for b in batch]
+ batch_indexes = np.argsort(lengths)[::-1]
+ batch = [batch[bid] for bid in batch_indexes]
+
+ nmels = batch[0][1].size(0)
+ max_mel_length = max([b[1].shape[1] for b in batch])
+ max_text_length = max([b[2].shape[0] for b in batch])
+ max_rtext_length = max([b[3].shape[0] for b in batch])
+
+ labels = torch.zeros((batch_size)).long()
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
+ texts = torch.zeros((batch_size, max_text_length)).long()
+ ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
+
+ input_lengths = torch.zeros(batch_size).long()
+ ref_lengths = torch.zeros(batch_size).long()
+ output_lengths = torch.zeros(batch_size).long()
+ ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
+ ref_labels = torch.zeros((batch_size)).long()
+ paths = ['' for _ in range(batch_size)]
+ waves = [None for _ in range(batch_size)]
+
+ for bid, (label, mel, text, ref_text, ref_mel, ref_label, path, wave) in enumerate(batch):
+ mel_size = mel.size(1)
+ text_size = text.size(0)
+ rtext_size = ref_text.size(0)
+ labels[bid] = label
+ mels[bid, :, :mel_size] = mel
+ texts[bid, :text_size] = text
+ ref_texts[bid, :rtext_size] = ref_text
+ input_lengths[bid] = text_size
+ ref_lengths[bid] = rtext_size
+ output_lengths[bid] = mel_size
+ paths[bid] = path
+ ref_mel_size = ref_mel.size(1)
+ ref_mels[bid, :, :ref_mel_size] = ref_mel
+
+ ref_labels[bid] = ref_label
+ waves[bid] = wave
+
+ return waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels
+
+
+
+def build_dataloader(path_list,
+ root_path,
+ validation=False,
+ OOD_data="Data/OOD_texts.txt",
+ min_length=50,
+ batch_size=4,
+ num_workers=1,
+ device='cpu',
+ collate_config={},
+ dataset_config={}):
+
+ dataset = FilePathDataset(path_list, root_path, OOD_data=OOD_data, min_length=min_length, validation=validation, **dataset_config)
+ collate_fn = Collater(**collate_config)
+ data_loader = DataLoader(dataset,
+ batch_size=batch_size,
+ shuffle=(not validation),
+ num_workers=num_workers,
+ drop_last=True,
+ collate_fn=collate_fn,
+ pin_memory=(device != 'cpu'))
+
+ return data_loader
+
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..046c7b97505a7f41caf5bfd8ddb3b10c04d24df8
--- /dev/null
+++ b/models.py
@@ -0,0 +1,1022 @@
+import os
+import os.path as osp
+
+import copy
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from Utils.ASR.models import ASRCNN
+from Utils.JDC.model import JDCNet
+
+
+from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
+
+from Modules.KotoDama_sampler import KotoDama_Prompt, KotoDama_Text
+from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
+from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
+from Modules.diffusion.diffusion import AudioDiffusionConditional
+from Modules.diffusion.audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler, DiffusionUpsampler
+
+from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
+
+from munch import Munch
+import yaml
+
+# from hflayers import Hopfield, HopfieldPooling, HopfieldLayer
+# from hflayers.auxiliary.data import BitPatternSet
+
+# Import auxiliary modules.
+from distutils.version import LooseVersion
+from typing import List, Tuple
+
+import math
+# from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+# from liger_kernel.transformers.experimental.embedding import nn.Embedding
+
+import torch
+
+from xlstm import (
+ xLSTMBlockStack,
+ xLSTMBlockStackConfig,
+ mLSTMBlockConfig,
+ mLSTMLayerConfig,
+ sLSTMBlockConfig,
+ sLSTMLayerConfig,
+ FeedForwardConfig,
+)
+
+
+
+class LearnedDownSample(nn.Module):
+ def __init__(self, layer_type, dim_in):
+ super().__init__()
+ self.layer_type = layer_type
+
+ if self.layer_type == 'none':
+ self.conv = nn.Identity()
+ elif self.layer_type == 'timepreserve':
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
+ elif self.layer_type == 'half':
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
+ else:
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+ def forward(self, x):
+ return self.conv(x)
+
+class LearnedUpSample(nn.Module):
+ def __init__(self, layer_type, dim_in):
+ super().__init__()
+ self.layer_type = layer_type
+
+ if self.layer_type == 'none':
+ self.conv = nn.Identity()
+ elif self.layer_type == 'timepreserve':
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
+ elif self.layer_type == 'half':
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
+ else:
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+
+ def forward(self, x):
+ return self.conv(x)
+
+class DownSample(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ elif self.layer_type == 'timepreserve':
+ return F.avg_pool2d(x, (2, 1))
+ elif self.layer_type == 'half':
+ if x.shape[-1] % 2 != 0:
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+ return F.avg_pool2d(x, 2)
+ else:
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+
+class UpSample(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ elif self.layer_type == 'timepreserve':
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+ elif self.layer_type == 'half':
+ return F.interpolate(x, scale_factor=2, mode='nearest')
+ else:
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
+
+
+class ResBlk(nn.Module):
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+ normalize=False, downsample='none'):
+ super().__init__()
+ self.actv = actv
+ self.normalize = normalize
+ self.downsample = DownSample(downsample)
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out)
+
+ def _build_weights(self, dim_in, dim_out):
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
+ if self.normalize:
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
+ if self.learned_sc:
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def _shortcut(self, x):
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ if self.downsample:
+ x = self.downsample(x)
+ return x
+
+ def _residual(self, x):
+ if self.normalize:
+ x = self.norm1(x)
+ x = self.actv(x)
+ x = self.conv1(x)
+ x = self.downsample_res(x)
+ if self.normalize:
+ x = self.norm2(x)
+ x = self.actv(x)
+ x = self.conv2(x)
+ return x
+
+ def forward(self, x):
+ x = self._shortcut(x) + self._residual(x)
+ return x / math.sqrt(2) # unit variance
+
+class StyleEncoder(nn.Module):
+ def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
+ super().__init__()
+ blocks = []
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+ repeat_num = 4
+ for _ in range(repeat_num):
+ dim_out = min(dim_in*2, max_conv_dim)
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+ dim_in = dim_out
+
+ blocks += [nn.LeakyReLU(0.2)]
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+ blocks += [nn.AdaptiveAvgPool2d(1)]
+ blocks += [nn.LeakyReLU(0.2)]
+ self.shared = nn.Sequential(*blocks)
+
+ self.unshared = nn.Linear(dim_out, style_dim)
+
+ def forward(self, x):
+ h = self.shared(x)
+ h = h.view(h.size(0), -1)
+ s = self.unshared(h)
+
+ return s
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+class Discriminator2d(nn.Module):
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
+ super().__init__()
+ blocks = []
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+ for lid in range(repeat_num):
+ dim_out = min(dim_in*2, max_conv_dim)
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+ dim_in = dim_out
+
+ blocks += [nn.LeakyReLU(0.2)]
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+ blocks += [nn.LeakyReLU(0.2)]
+ blocks += [nn.AdaptiveAvgPool2d(1)]
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
+ self.main = nn.Sequential(*blocks)
+
+ def get_feature(self, x):
+ features = []
+ for l in self.main:
+ x = l(x)
+ features.append(x)
+ out = features[-1]
+ out = out.view(out.size(0), -1) # (batch, num_domains)
+ return out, features
+
+ def forward(self, x):
+ out, features = self.get_feature(x)
+ out = out.squeeze() # (batch)
+ return out, features
+
+class ResBlk1d(nn.Module):
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+ normalize=False, downsample='none', dropout_p=0.2):
+ super().__init__()
+ self.actv = actv
+ self.normalize = normalize
+ self.downsample_type = downsample
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out)
+ self.dropout_p = dropout_p
+
+ if self.downsample_type == 'none':
+ self.pool = nn.Identity()
+ else:
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
+
+ def _build_weights(self, dim_in, dim_out):
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+ if self.normalize:
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
+ if self.learned_sc:
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def downsample(self, x):
+ if self.downsample_type == 'none':
+ return x
+ else:
+ if x.shape[-1] % 2 != 0:
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+ return F.avg_pool1d(x, 2)
+
+ def _shortcut(self, x):
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ x = self.downsample(x)
+ return x
+
+ def _residual(self, x):
+ if self.normalize:
+ x = self.norm1(x)
+ x = self.actv(x)
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+ x = self.conv1(x)
+ x = self.pool(x)
+ if self.normalize:
+ x = self.norm2(x)
+
+ x = self.actv(x)
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+ x = self.conv2(x)
+ return x
+
+ def forward(self, x):
+ x = self._shortcut(x) + self._residual(x)
+ return x / math.sqrt(2) # unit variance
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class TextEncoder(nn.Module):
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
+ super().__init__()
+ self.embedding = nn.Embedding(n_symbols, channels)
+
+ self.prepare_projection=LinearNorm(channels,channels // 2)
+ self.post_projection=LinearNorm(channels // 2,channels)
+ self.cfg = xLSTMBlockStackConfig(
+ mlstm_block=mLSTMBlockConfig(
+ mlstm=mLSTMLayerConfig(
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
+ )
+ ),
+ # slstm_block=sLSTMBlockConfig(
+ # slstm=sLSTMLayerConfig(
+ # backend="cuda",
+ # num_heads=4,
+ # conv1d_kernel_size=4,
+ # bias_init="powerlaw_blockdependent",
+ # ),
+ # feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
+ # ),
+ context_length=channels,
+ num_blocks=8,
+ embedding_dim=channels // 2,
+ # slstm_at=[1],
+
+ )
+
+
+
+ padding = (kernel_size - 1) // 2
+ self.cnn = nn.ModuleList()
+ for _ in range(depth):
+ self.cnn.append(nn.Sequential(
+
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
+ LayerNorm(channels),
+ actv,
+ nn.Dropout(0.2),
+ ))
+ # self.cnn = nn.Sequential(*self.cnn)
+
+
+ self.lstm = xLSTMBlockStack(self.cfg)
+ def forward(self, x, input_lengths, m):
+
+ x = self.embedding(x) # [B, T, emb]
+
+
+ x = x.transpose(1, 2) # [B, emb, T]
+ m = m.to(input_lengths.device).unsqueeze(1)
+ x.masked_fill_(m, 0.0)
+
+ for c in self.cnn:
+ x = c(x)
+ x.masked_fill_(m, 0.0)
+
+ x = x.transpose(1, 2) # [B, T, chn]
+
+
+ input_lengths = input_lengths.cpu().numpy()
+
+
+
+ x = self.prepare_projection(x)
+
+ # x = nn.utils.rnn.pack_padded_sequence(
+ # x, input_lengths, batch_first=True, enforce_sorted=False)
+
+ # self.lstm.flatten_parameters()
+ x = self.lstm(x)
+
+ x = self.post_projection(x)
+ # x, _ = nn.utils.rnn.pad_packed_sequence(
+ # x, batch_first=True)
+
+ x = x.transpose(-1, -2)
+# x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+# x_pad[:, :, :x.shape[-1]] = x
+# x = x_pad.to(x.device)
+
+ x.masked_fill_(m, 0.0)
+
+ return x
+
+ def inference(self, x):
+ x = self.embedding(x)
+ x = x.transpose(1, 2)
+ x = self.cnn(x)
+ x = x.transpose(1, 2)
+ # self.lstm.flatten_parameters()
+ x = self.lstm(x)
+ return x
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
+
+
+class AdaIN1d(nn.Module):
+ def __init__(self, style_dim, num_features):
+ super().__init__()
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
+ self.fc = nn.Linear(style_dim, num_features*2)
+
+ def forward(self, x, s):
+ h = self.fc(s)
+
+ h = h.view(h.size(0), h.size(1), 1)
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
+ return (1 + gamma) * self.norm(x) + beta
+
+class UpSample1d(nn.Module):
+ def __init__(self, layer_type):
+ super().__init__()
+ self.layer_type = layer_type
+
+ def forward(self, x):
+ if self.layer_type == 'none':
+ return x
+ else:
+ return F.interpolate(x, scale_factor=2, mode='nearest')
+
+class AdainResBlk1d(nn.Module):
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+ upsample='none', dropout_p=0.0):
+ super().__init__()
+ self.actv = actv
+ self.upsample_type = upsample
+ self.upsample = UpSample1d(upsample)
+ self.learned_sc = dim_in != dim_out
+ self._build_weights(dim_in, dim_out, style_dim)
+ self.dropout = nn.Dropout(dropout_p)
+
+ if upsample == 'none':
+ self.pool = nn.Identity()
+ else:
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+
+ def _build_weights(self, dim_in, dim_out, style_dim):
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+ self.norm1 = AdaIN1d(style_dim, dim_in)
+ self.norm2 = AdaIN1d(style_dim, dim_out)
+ if self.learned_sc:
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+ def _shortcut(self, x):
+ x = self.upsample(x)
+ if self.learned_sc:
+ x = self.conv1x1(x)
+ return x
+
+ def _residual(self, x, s):
+ x = self.norm1(x, s)
+ x = self.actv(x)
+ x = self.pool(x)
+ x = self.conv1(self.dropout(x))
+ x = self.norm2(x, s)
+ x = self.actv(x)
+ x = self.conv2(self.dropout(x))
+ return x
+
+ def forward(self, x, s):
+ out = self._residual(x, s)
+ out = (out + self._shortcut(x)) / math.sqrt(2)
+ return out
+
+class AdaLayerNorm(nn.Module):
+ def __init__(self, style_dim, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.fc = nn.Linear(style_dim, channels*2)
+
+ def forward(self, x, s):
+ x = x.transpose(-1, -2)
+ x = x.transpose(1, -1)
+
+ h = self.fc(s)
+ h = h.view(h.size(0), h.size(1), 1)
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+
+
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
+ x = (1 + gamma) * x + beta
+ return x.transpose(1, -1).transpose(-1, -2)
+
+# class ProsodyPredictor(nn.Module):
+
+# def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
+# super().__init__()
+
+# self.text_encoder = DurationEncoder(sty_dim=style_dim,
+# d_model=d_hid,
+# nlayers=nlayers,
+# dropout=dropout)
+
+# self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# self.duration_proj = LinearNorm(d_hid, max_dur)
+
+# self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+# self.F0 = nn.ModuleList()
+# self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# self.N = nn.ModuleList()
+# self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+# self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+# self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+# self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+# self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+
+# def forward(self, texts, style, text_lengths, alignment, m):
+# d = self.text_encoder(texts, style, text_lengths, m)
+
+# batch_size = d.shape[0]
+# text_size = d.shape[1]
+
+# # predict duration
+# input_lengths = text_lengths.cpu().numpy()
+# x = nn.utils.rnn.pack_padded_sequence(
+# d, input_lengths, batch_first=True, enforce_sorted=False)
+
+# m = m.to(text_lengths.device).unsqueeze(1)
+
+# self.lstm.flatten_parameters()
+# x, _ = self.lstm(x)
+# x, _ = nn.utils.rnn.pad_packed_sequence(
+# x, batch_first=True)
+
+# x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+# x_pad[:, :x.shape[1], :] = x
+# x = x_pad.to(x.device)
+
+# duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+# en = (d.transpose(-1, -2) @ alignment)
+
+# return duration.squeeze(-1), en
+
+
+class ProsodyPredictor(nn.Module):
+
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
+ super().__init__()
+
+ self.cfg = xLSTMBlockStackConfig(
+ mlstm_block=mLSTMBlockConfig(
+ mlstm=mLSTMLayerConfig(
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
+ )
+ ),
+ context_length=d_hid,
+ num_blocks=8,
+ embedding_dim=d_hid + style_dim,
+
+
+ )
+
+ self.cfg_pred = xLSTMBlockStackConfig(
+ mlstm_block=mLSTMBlockConfig(
+ mlstm=mLSTMLayerConfig(
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
+ )
+ ),
+
+ context_length=4096,
+ num_blocks=8,
+ embedding_dim=d_hid + style_dim,
+
+ )
+
+
+ # self.shared = Hopfield(input_size=d_hid + style_dim,
+ # hidden_size=d_hid // 2,
+ # num_heads=32,
+ # # scaling=.75,
+ # add_zero_association=True,
+ # batch_first=True)
+
+ # if you want to use hopfield, just comment out the block above, then hash the "self.shared below"
+
+
+
+
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
+ d_model=d_hid,
+ nlayers=nlayers,
+ dropout=dropout)
+
+
+ self.lstm = xLSTMBlockStack(self.cfg)
+
+ self.prepare_projection = nn.Linear(d_hid + style_dim, d_hid)
+
+ self.duration_proj = LinearNorm(d_hid , max_dur)
+
+ self.shared = xLSTMBlockStack(self.cfg_pred)
+
+ # self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+
+ self.F0 = nn.ModuleList()
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+ self.N = nn.ModuleList()
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+
+ def forward(self, texts, style, text_lengths=None, alignment=None, m=None, f0=False):
+
+ if f0:
+ x, s = texts, style
+ # x = self.prepare_projection(x.transpose(-1, -2))
+ # x = self.shared(x)
+
+ x = self.shared(x.transpose(-1, -2))
+ x = self.prepare_projection(x)
+
+ F0 = x.transpose(-1, -2)
+ for block in self.F0:
+ F0 = block(F0, s)
+ F0 = self.F0_proj(F0)
+
+ N = x.transpose(-1, -2)
+ for block in self.N:
+ N = block(N, s)
+ N = self.N_proj(N)
+
+ return F0.squeeze(1), N.squeeze(1)
+
+ else:
+ # Problem is here
+ d = self.text_encoder(texts, style, text_lengths, m)
+
+ batch_size = d.shape[0]
+ text_size = d.shape[1]
+
+ # predict duration
+
+
+ input_lengths = text_lengths.cpu().numpy()
+
+
+ # x = nn.utils.rnn.pack_padded_sequence(
+ # d, input_lengths, batch_first=True, enforce_sorted=False)
+
+ x = d # this dude can handle variable seq len so no need for padding
+
+
+ m = m.to(text_lengths.device).unsqueeze(1)
+
+ # self.lstm.flatten_parameters()
+ x = self.lstm(x) # no longer using lstm
+ x = self.prepare_projection(x)
+
+
+ # x, _ = nn.utils.rnn.pad_packed_sequence(
+ # x, batch_first=True)
+
+ # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+ # x_pad[:, :x.shape[1], :] = x
+ # x = x_pad.to(x.device)
+
+ x = x.transpose(-1,-2)
+ x = x.permute(0,2,1)
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+
+
+ en = (d.transpose(-1, -2) @ alignment)
+
+ return duration.squeeze(-1), en
+
+
+ def F0Ntrain(self, x, s):
+
+
+ # x = self.prepare_projection(x.transpose(-1, -2))
+ # x = self.shared(x)
+
+ ####
+ x = self.shared(x.transpose(-1, -2))
+ x = self.prepare_projection(x)
+
+
+
+ F0 = x.transpose(-1, -2)
+
+ for block in self.F0:
+ F0 = block(F0, s)
+ F0 = self.F0_proj(F0)
+
+ N = x.transpose(-1, -2)
+ for block in self.N:
+ N = block(N, s)
+ N = self.N_proj(N)
+
+ return F0.squeeze(1), N.squeeze(1)
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
+class DurationEncoder(nn.Module):
+
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
+ super().__init__()
+ self.lstms = nn.ModuleList()
+ for _ in range(nlayers):
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
+ d_model // 2,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=True,
+ dropout=dropout))
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
+
+
+ self.dropout = dropout
+ self.d_model = d_model
+ self.sty_dim = sty_dim
+
+ def forward(self, x, style, text_lengths, m):
+ masks = m.to(text_lengths.device)
+
+ x = x.permute(2, 0, 1)
+ s = style.expand(x.shape[0], x.shape[1], -1)
+ x = torch.cat([x, s], axis=-1)
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
+
+ x = x.transpose(0, 1)
+ input_lengths = text_lengths.cpu().numpy()
+ x = x.transpose(-1, -2)
+
+ for block in self.lstms:
+ if isinstance(block, AdaLayerNorm):
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
+ else:
+ x = x.transpose(-1, -2)
+ x = nn.utils.rnn.pack_padded_sequence(
+ x, input_lengths, batch_first=True, enforce_sorted=False)
+ block.flatten_parameters()
+ x, _ = block(x)
+ x, _ = nn.utils.rnn.pad_packed_sequence(
+ x, batch_first=True)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = x.transpose(-1, -2)
+
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+ x_pad[:, :, :x.shape[-1]] = x
+ x = x_pad.to(x.device)
+
+ return x.transpose(-1, -2)
+
+ def inference(self, x, style):
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
+ style = style.expand(x.shape[0], x.shape[1], -1)
+ x = torch.cat([x, style], axis=-1)
+ src = self.pos_encoder(x)
+ output = self.transformer_encoder(src).transpose(0, 1)
+ return output
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
+ def inference(self, x, style):
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
+ style = style.expand(x.shape[0], x.shape[1], -1)
+ x = torch.cat([x, style], axis=-1)
+ src = self.pos_encoder(x)
+ output = self.transformer_encoder(src).transpose(0, 1)
+ return output
+
+ def length_to_mask(self, lengths):
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
+ return mask
+
+
+
+def load_F0_models(path):
+ # load F0 model
+
+ F0_model = JDCNet(num_class=1, seq_len=192)
+ params = torch.load(path, map_location='cpu')['net']
+ F0_model.load_state_dict(params)
+ _ = F0_model.train()
+
+ return F0_model
+
+
+def load_KotoDama_Prompter(path, cfg=None, model_ckpt="ku-nlp/deberta-v3-base-japanese"):
+
+ cfg = AutoConfig.from_pretrained(model_ckpt)
+ cfg.update({
+ "num_labels": 256
+ })
+
+ kotodama_prompt = KotoDama_Prompt.from_pretrained(path, config=cfg)
+
+ return kotodama_prompt
+
+
+def load_KotoDama_TextSampler(path, cfg=None, model_ckpt="line-corporation/line-distilbert-base-japanese"):
+
+ cfg = AutoConfig.from_pretrained(model_ckpt)
+ cfg.update({
+ "num_labels": 256
+ })
+
+ kotodama_sampler = KotoDama_Text.from_pretrained(path, config=cfg)
+
+ return kotodama_sampler
+
+
+
+# def reconstruction_head(path): # didn't make a lot of difference, disabling it for now until i find / train a better net
+
+# recon_model = DiffusionUpsampler(
+
+# net_t=UNetV0,
+# upsample_factor=2,
+# in_channels=1,
+# channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],
+# factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
+# items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
+# diffusion_t=VDiffusion,
+# sampler_t=VSampler,
+# )
+
+# checkpoint = torch.load(path, map_location='cpu')
+
+# new_state_dict = {}
+# for key, value in checkpoint['model_state_dict'].items():
+# new_key = key.replace('module.', '') # Remove 'module.' prefix
+# new_state_dict[new_key] = value
+
+# recon_model.load_state_dict(new_state_dict)
+# recon_model.eval()
+
+# recon_model = recon_model.to('cuda')
+
+# return recon_model
+
+
+
+def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
+ # load ASR model
+ def _load_config(path):
+ with open(path) as f:
+ config = yaml.safe_load(f)
+ model_config = config['model_params']
+ return model_config
+
+ def _load_model(model_config, model_path):
+ model = ASRCNN(**model_config)
+ params = torch.load(model_path, map_location='cpu')['model']
+ model.load_state_dict(params)
+ return model
+
+ asr_model_config = _load_config(ASR_MODEL_CONFIG)
+ asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
+ _ = asr_model.train()
+
+ return asr_model
+
+def build_model(args, text_aligner, pitch_extractor, bert, KotoDama_Prompt, KotoDama_Text):
+ assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
+
+ if args.decoder.type == "istftnet":
+ from Modules.istftnet import Decoder
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
+ upsample_rates = args.decoder.upsample_rates,
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
+ else:
+ from Modules.hifigan import Decoder
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
+ upsample_rates = args.decoder.upsample_rates,
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
+
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
+
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
+
+ style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
+ predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
+
+ # define diffusion model
+ if args.multispeaker:
+ transformer = StyleTransformer1d(channels=args.style_dim*2,
+ context_embedding_features=bert.config.hidden_size,
+ context_features=args.style_dim*2,
+ **args.diffusion.transformer)
+ else:
+ transformer = Transformer1d(channels=args.style_dim*2,
+ context_embedding_features=bert.config.hidden_size,
+ **args.diffusion.transformer)
+
+ diffusion = AudioDiffusionConditional(
+ in_channels=1,
+ embedding_max_length=bert.config.max_position_embeddings,
+ embedding_features=bert.config.hidden_size,
+ embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
+ channels=args.style_dim*2,
+ context_features=args.style_dim*2,
+ )
+
+ diffusion.diffusion = KDiffusion(
+ net=diffusion.unet,
+ sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
+ sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
+ dynamic_threshold=0.0
+ )
+ diffusion.diffusion.net = transformer
+ diffusion.unet = transformer
+
+
+ nets = Munch(
+
+ bert=bert,
+ bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
+
+ predictor=predictor,
+ decoder=decoder,
+ text_encoder=text_encoder,
+
+ predictor_encoder=predictor_encoder,
+ style_encoder=style_encoder,
+ diffusion=diffusion,
+
+ text_aligner = text_aligner,
+ pitch_extractor = pitch_extractor,
+
+ mpd = MultiPeriodDiscriminator(),
+ msd = MultiResSpecDiscriminator(),
+
+ # slm discriminator head
+ wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
+
+ KotoDama_Prompt = KotoDama_Prompt,
+ KotoDama_Text = KotoDama_Text,
+
+ # recon_diff = recon_diff,
+
+ )
+
+ return nets
+
+
+def load_checkpoint(model, optimizer, path, load_only_params=False, ignore_modules=[]):
+ state = torch.load(path, map_location='cpu')
+ params = state['net']
+ print('loading the ckpt using the correct function.')
+
+ for key in model:
+ if key in params and key not in ignore_modules:
+ try:
+ model[key].load_state_dict(params[key], strict=True)
+ except:
+ from collections import OrderedDict
+ state_dict = params[key]
+ new_state_dict = OrderedDict()
+ print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict key length: {len(state_dict.keys())}')
+ for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
+ new_state_dict[k_m] = v_c
+ model[key].load_state_dict(new_state_dict, strict=True)
+ print('%s loaded' % key)
+
+ if not load_only_params:
+ epoch = state["epoch"]
+ iters = state["iters"]
+ optimizer.load_state_dict(state["optimizer"])
+ else:
+ epoch = 0
+ iters = 0
+
+ return model, optimizer, epoch, iters
diff --git a/optimizers.py b/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e192bc1f2d2c21f63c9c7d3bf6195715c8713278
--- /dev/null
+++ b/optimizers.py
@@ -0,0 +1,86 @@
+# coding:utf-8
+import os, sys
+import os.path as osp
+import numpy as np
+import torch
+from torch import nn
+from torch.optim import Optimizer
+from functools import reduce
+from torch.optim import AdamW
+
+
+class MultiOptimizer:
+ def __init__(self, optimizers={}, schedulers={}):
+ self.optimizers = optimizers
+ self.schedulers = schedulers
+ self.keys = list(optimizers.keys())
+ self.param_groups = reduce(
+ lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
+ )
+
+ def state_dict(self):
+ state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
+ return state_dicts
+
+ def load_state_dict(self, state_dict):
+ for key, val in state_dict:
+ try:
+ self.optimizers[key].load_state_dict(val)
+ except:
+ print("Unloaded %s" % key)
+
+ def step(self, key=None, scaler=None):
+ keys = [key] if key is not None else self.keys
+ _ = [self._step(key, scaler) for key in keys]
+
+ def _step(self, key, scaler=None):
+ if scaler is not None:
+ scaler.step(self.optimizers[key])
+ scaler.update()
+ else:
+ self.optimizers[key].step()
+
+ def zero_grad(self, key=None):
+ if key is not None:
+ self.optimizers[key].zero_grad()
+ else:
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
+
+ def scheduler(self, *args, key=None):
+ if key is not None:
+ self.schedulers[key].step(*args)
+ else:
+ _ = [self.schedulers[key].step(*args) for key in self.keys]
+
+
+def define_scheduler(optimizer, params):
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=params.get("max_lr", 2e-4),
+ epochs=params.get("epochs", 200),
+ steps_per_epoch=params.get("steps_per_epoch", 1000),
+ pct_start=params.get("pct_start", 0.0),
+ div_factor=1,
+ final_div_factor=1,
+ )
+
+ return scheduler
+
+
+def build_optimizer(parameters_dict, scheduler_params_dict, lr):
+ optim = dict(
+ [
+ (key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9))
+ for key, params in parameters_dict.items()
+ ]
+ )
+
+ schedulers = dict(
+ [
+ (key, define_scheduler(opt, scheduler_params_dict[key]))
+ for key, opt in optim.items()
+ ]
+ )
+
+ multi_optim = MultiOptimizer(optim, schedulers)
+ return multi_optim
diff --git a/reference_sample_wavs/01001240.ogg b/reference_sample_wavs/01001240.ogg
new file mode 100644
index 0000000000000000000000000000000000000000..cb1cf76b1e086d27621dbc6fd01f7cb385ac7246
Binary files /dev/null and b/reference_sample_wavs/01001240.ogg differ
diff --git a/reference_sample_wavs/01008270.wav b/reference_sample_wavs/01008270.wav
new file mode 100644
index 0000000000000000000000000000000000000000..21ee524eae2607c1676394f9b6a364fe7300997e
--- /dev/null
+++ b/reference_sample_wavs/01008270.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a306920452276a8b801454ba5d540c7f3c28a3fc0d5ce01bf4a3f679e0f42c3
+size 1082540
diff --git a/reference_sample_wavs/kaede_san.wav b/reference_sample_wavs/kaede_san.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3e047dfee2e40f934574d301cdf15733a1c07ca3
--- /dev/null
+++ b/reference_sample_wavs/kaede_san.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:376737a52bf7f67ba6035597bae5ad87b5220d005bad78318d3f8062eb9ff692
+size 1812558
diff --git a/reference_sample_wavs/riamu_zeroshot_01.wav b/reference_sample_wavs/riamu_zeroshot_01.wav
new file mode 100644
index 0000000000000000000000000000000000000000..43a753e1a49fac7fbcfae3eb8ae9206c7c95f356
Binary files /dev/null and b/reference_sample_wavs/riamu_zeroshot_01.wav differ
diff --git a/reference_sample_wavs/riamu_zeroshot_02.wav b/reference_sample_wavs/riamu_zeroshot_02.wav
new file mode 100644
index 0000000000000000000000000000000000000000..007965ae590c3d84da413dea053c870abebec125
--- /dev/null
+++ b/reference_sample_wavs/riamu_zeroshot_02.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa74683a6ac7dca963e3ae4b10f5984902683bd55d2806542b8821a9d07beaa2
+size 1427500
diff --git a/reference_sample_wavs/sample_ref01.wav b/reference_sample_wavs/sample_ref01.wav
new file mode 100644
index 0000000000000000000000000000000000000000..101ff58e67e315e03be66293d061eb12e3863569
--- /dev/null
+++ b/reference_sample_wavs/sample_ref01.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a4241b264d96819f6d4290c23861401f7b116bbb9fa9aace8b65add01b0d812b
+size 1644002
diff --git a/reference_sample_wavs/sample_ref02.wav b/reference_sample_wavs/sample_ref02.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e181f5f3ad40eb49a87e31641bd998e227152391
--- /dev/null
+++ b/reference_sample_wavs/sample_ref02.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a38f2f13d1035d0148d965410cde080c6371c9460e17f147ef130aacb4551b1c
+size 1803998
diff --git a/reference_sample_wavs/shiki_fine05.wav b/reference_sample_wavs/shiki_fine05.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ee2897bd850040d2505a9f885574114aa369e395
--- /dev/null
+++ b/reference_sample_wavs/shiki_fine05.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd063aba7ad59b2bbfb5ed57d164dbbf75c70b91b163c851ca334661911c16c5
+size 2123200
diff --git a/reference_sample_wavs/syuukovoice_200918_3_01.wav b/reference_sample_wavs/syuukovoice_200918_3_01.wav
new file mode 100644
index 0000000000000000000000000000000000000000..048162b44269ddd918d0ff5b39ca93aaec87f1d0
--- /dev/null
+++ b/reference_sample_wavs/syuukovoice_200918_3_01.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7929dc92fdfa61ba580a20d95a677d1f6fe8de10edeae6778d664075e43aeb02
+size 1979500
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1e29d1b75ee20ecdd5814e04d57c5e9e3c896bff
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+SoundFile
+munch
+pydub
+pyyaml
+torchaudio
+librosa
+pydub
+pyyaml
+nltk
+matplotlib
+unidic-lite
+xlstm
+fugashi
+cutlet
+accelerate
+transformers
+einops
+einops-exts
+tqdm
+typing
+typing-extensions
+a_unet
+git+https://github.com/resemble-ai/monotonic_align.git
\ No newline at end of file
diff --git a/text_utils.py b/text_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c4bb9929c1165bb3928a6224f4040dd4636ec90
--- /dev/null
+++ b/text_utils.py
@@ -0,0 +1,28 @@
+# IPA Phonemizer: https://github.com/bootphon/phonemizer
+
+_pad = "$"
+_punctuation = ';:,.!?¡¿—…"«»“” '
+_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+
+# Export all symbols:
+symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+dicts = {}
+for i in range(len((symbols))):
+ dicts[symbols[i]] = i
+
+
+class TextCleaner:
+ def __init__(self, dummy=None):
+ self.word_index_dictionary = dicts
+ print(len(dicts))
+
+ def __call__(self, text):
+ indexes = []
+ for char in text:
+ try:
+ indexes.append(self.word_index_dictionary[char])
+ except KeyError:
+ print(text)
+ return indexes
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..369acc3ccda122b497faea59ec6426f33af6ed55
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,89 @@
+from monotonic_align import maximum_path
+from monotonic_align import mask_from_lens
+from monotonic_align.core import maximum_path_c
+import numpy as np
+import torch
+import copy
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+import librosa
+import matplotlib.pyplot as plt
+from munch import Munch
+
+
+def maximum_path(neg_cent, mask):
+ """Cython optimized version.
+ neg_cent: [b, t_t, t_s]
+ mask: [b, t_t, t_s]
+ """
+ device = neg_cent.device
+ dtype = neg_cent.dtype
+ neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
+ path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
+
+ t_t_max = np.ascontiguousarray(
+ mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
+ )
+ t_s_max = np.ascontiguousarray(
+ mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
+ )
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+
+def get_data_path_list(train_path=None, val_path=None):
+ if train_path is None:
+ train_path = "Data/train_list.txt"
+ if val_path is None:
+ val_path = "Data/val_list.txt"
+
+ with open(train_path, "r", encoding="utf-8", errors="ignore") as f:
+ train_list = f.readlines()
+ with open(val_path, "r", encoding="utf-8", errors="ignore") as f:
+ val_list = f.readlines()
+
+ return train_list, val_list
+
+
+def length_to_mask(lengths):
+ mask = (
+ torch.arange(lengths.max())
+ .unsqueeze(0)
+ .expand(lengths.shape[0], -1)
+ .type_as(lengths)
+ )
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
+ return mask
+
+
+# for norm consistency loss
+def log_norm(x, mean=-4, std=4, dim=2):
+ """
+ normalized log mel -> mel -> norm -> log(norm)
+ """
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
+ return x
+
+
+def get_image(arrs):
+ plt.switch_backend("agg")
+ fig = plt.figure()
+ ax = plt.gca()
+ ax.imshow(arrs)
+
+ return fig
+
+
+def recursive_munch(d):
+ if isinstance(d, dict):
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
+ elif isinstance(d, list):
+ return [recursive_munch(v) for v in d]
+ else:
+ return d
+
+
+def log_print(message, logger):
+ logger.info(message)
+ print(message)