File size: 5,021 Bytes
12da6cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import random
from pathlib import Path
import numpy as np
import textgrid
from scipy.io import wavfile
from .config import FLAGS, AcousticInput, DurationInput
def load_phonemes_set():
    """Return the phoneme inventory used to index phonemes.

    The result is ``FLAGS.special_phonemes`` followed by
    ``FLAGS._normal_phonemes``; callers in this file use the position of a
    phoneme in this sequence as its integer id.
    """
    return FLAGS.special_phonemes + FLAGS._normal_phonemes
def pad_seq(s, maxlen, value=0):
    """Right-pad ``s`` with ``value`` up to ``maxlen`` items and return a tuple.

    ``s`` must not be longer than ``maxlen``; this is enforced with an
    ``assert``, so an over-long sequence raises ``AssertionError``.
    """
    missing = maxlen - len(s)
    assert missing >= 0
    filler = (value,) * missing
    return tuple(s) + filler
def is_in_word(phone, word):
    """Tell whether ``phone`` lies within ``word``, with a 1e-3 tolerance.

    Both ``phone.minTime`` and ``phone.maxTime`` must fall strictly between
    ``word.minTime - 1e-3`` and ``word.maxTime + 1e-3``.
    """
    lower = word.minTime - 1e-3
    upper = word.maxTime + 1e-3
    return all(lower < t and upper > t for t in (phone.minTime, phone.maxTime))
def load_textgrid(fn: Path):
    """Read a TextGrid file into a list of ``(phoneme, duration)`` pairs.

    Tier 0 is used as the word tier and tier 1 as the phoneme tier. Phonemes
    are walked in order; whenever a phoneme is not contained in the current
    word (per the ``in`` test of the textgrid interval objects), the walk
    advances to the next word. If the word just left has a non-empty mark, a
    zero-duration ``FLAGS.special_phonemes[FLAGS.word_end_index]`` entry is
    emitted first. Phoneme marks are stripped and lower-cased, and an empty
    mark is recorded as ``"sil"``.
    """
    grid = textgrid.TextGrid.fromFile(str(fn.resolve()))
    entries = []
    word_tier = list(grid[0])
    phone_tier = grid[1]
    cursor = 0  # index of the word the current phoneme should belong to
    assert phone_tier[0].minTime == 0, "The first phoneme has to start at time 0"
    for phone in phone_tier:
        if phone not in word_tier[cursor]:
            # The phoneme left the current word: close it and move on.
            cursor += 1
            finished_word = word_tier[cursor - 1]
            if len(finished_word.mark) > 0:
                entries.append((FLAGS.special_phonemes[FLAGS.word_end_index], 0.0))
            if cursor >= len(word_tier):
                # Ran out of words; remaining phonemes are dropped.
                break
            assert phone in word_tier[cursor], "mismatched word vs phoneme"
        label = phone.mark.strip().lower() or "sil"
        entries.append((label, phone.duration()))
    return entries
def textgrid_data_loader(data_dir: Path, seq_len: int, batch_size: int, mode: str):
    """Yield an endless stream of ``DurationInput`` batches from TextGrid files.

    All ``*.TextGrid`` files directly inside ``data_dir`` are sorted, shuffled
    with a fixed seed (42) and split 95% / 5%: ``mode="train"`` uses the first
    part, ``mode="val"`` the rest. Every file is loaded up front, its phonemes
    are mapped to their index in ``load_phonemes_set()``, and both phoneme ids
    and durations are right-padded with 0 to ``seq_len``.

    The generator then loops forever: the data is reshuffled on each pass and
    a batch is emitted every ``batch_size`` items. A partially filled batch is
    carried over into the next pass rather than dropped.

    Args:
        data_dir: directory containing the ``*.TextGrid`` files.
        seq_len: padded length of each phoneme / duration sequence.
        batch_size: number of items per yielded batch.
        mode: ``"train"`` or ``"val"``.

    Yields:
        ``DurationInput(phonemes, lengths, durations)`` built from an int32
        array of phoneme ids, an int32 array of unpadded lengths and a float32
        array of durations, passed in that positional order.

    Raises:
        AssertionError: if ``mode`` is not ``"train"`` or ``"val"``.
        ValueError: if the selected split contains no files. Without this
            check the loop below would spin forever without yielding.

    Being a generator, nothing (including these checks) runs until the first
    item is requested.
    """
    tg_files = sorted(data_dir.glob("*.TextGrid"))
    random.Random(42).shuffle(tg_files)
    split = len(tg_files) * 95 // 100
    assert mode in ["train", "val"]
    tg_files = tg_files[:split] if mode == "train" else tg_files[split:]
    if not tg_files:
        raise ValueError(
            f"no TextGrid files available for mode {mode!r} in {data_dir}"
        )

    phonemes = load_phonemes_set()
    data = []
    for fn in tg_files:
        ps, ds = zip(*load_textgrid(fn))
        ps = [phonemes.index(p) for p in ps]
        length = len(ps)  # unpadded token count
        data.append((pad_seq(ps, seq_len, 0), pad_seq(ds, seq_len, 0), length))

    batch = []
    while True:
        random.shuffle(data)
        for item in data:
            batch.append(item)
            if len(batch) == batch_size:
                ps, ds, lengths = zip(*batch)
                ps = np.array(ps, dtype=np.int32)
                ds = np.array(ds, dtype=np.float32)
                lengths = np.array(lengths, dtype=np.int32)
                yield DurationInput(ps, lengths, ds)
                batch = []
def load_textgrid_wav(
    data_dir: Path, token_seq_len: int, batch_size, pad_wav_len, mode: str
):
    """Load wav and TextGrid files to memory and yield ``AcousticInput`` batches.

    All ``*.TextGrid`` files directly inside ``data_dir`` are sorted and
    shuffled with a fixed seed (42). ``mode="train"`` uses the first 95%,
    ``mode="val"`` the remaining 5%, and ``mode="gta"`` uses every file.
    For each file the matching ``<stem>.wav`` is read, the samples aligned to
    special phonemes (ids below ``len(FLAGS.special_phonemes)``) are zeroed,
    and the waveform is truncated / zero-padded to ``pad_wav_len`` samples.

    In "train" and "val" modes the generator never ends: data is reshuffled on
    each pass and a partially filled batch is carried into the next pass. In
    "gta" mode a single pass is made, the final (possibly smaller) batch is
    flushed, and each yielded item is ``(names, AcousticInput)``.

    Args:
        data_dir: directory holding ``*.TextGrid`` files and their wav files.
        token_seq_len: padded length of phoneme / duration sequences.
        batch_size: number of items per batch.
        pad_wav_len: fixed number of samples per waveform after padding.
        mode: ``"train"``, ``"val"`` or ``"gta"``.

    Yields:
        ``AcousticInput(phonemes, lengths, durations, wavs, wav_lengths, None)``
        (positional order), preceded by the tuple of file stems in "gta" mode.

    Raises:
        AssertionError: if ``mode`` is not one of the supported values.
        ValueError: in "train"/"val" mode when the split contains no files.
            Without this check the endless loop would spin without yielding.

    Being a generator, nothing (including these checks) runs until the first
    item is requested.
    """
    tg_files = sorted(data_dir.glob("*.TextGrid"))
    random.Random(42).shuffle(tg_files)
    split = len(tg_files) * 95 // 100
    assert mode in ["train", "val", "gta"]
    if mode == "train":
        tg_files = tg_files[:split]
    elif mode == "val":
        tg_files = tg_files[split:]
    # mode == "gta": keep every file.
    if mode != "gta" and not tg_files:
        raise ValueError(
            f"no TextGrid files available for mode {mode!r} in {data_dir}"
        )

    phonemes = load_phonemes_set()
    num_special = len(FLAGS.special_phonemes)

    data = []
    for fn in tg_files:
        ps, ds = zip(*load_textgrid(fn))
        ps = [phonemes.index(p) for p in ps]
        # Unpadded token count. It has its own name because the sample offsets
        # computed in the loop below used to overwrite it, so the batch
        # carried a sample index as the sequence length.
        num_tokens = len(ps)
        ps = pad_seq(ps, token_seq_len, 0)
        ds = pad_seq(ds, token_seq_len, 0)

        wav_file = data_dir / f"{fn.stem}.wav"
        sr, y = wavfile.read(wav_file)
        y = np.copy(y)  # private copy: samples are zeroed in place below

        # Silence every span aligned to a special phoneme. Padding slots use
        # id 0 and zero duration, so they are empty spans except for the very
        # last slot, which is stretched to the end of the recording.
        start_time = 0
        last_slot = len(ps) - 1
        for i, (phone_idx, duration) in enumerate(zip(ps, ds)):
            left = int(start_time * sr)
            end_time = start_time + duration
            right = len(y) if i == last_slot else int(end_time * sr)
            if phone_idx < num_special:
                y[left:right] = 0
            start_time = end_time

        if len(y) > pad_wav_len:
            y = y[:pad_wav_len]
        # NOTE: amplitude normalization to match hifigan preprocessing is
        # currently disabled; samples are passed through as read.
        wav_length = len(y)
        y = np.pad(y, (0, pad_wav_len - len(y)))
        data.append((fn.stem, ps, ds, num_tokens, y, wav_length))

    batch = []
    while True:
        random.shuffle(data)
        for idx, item in enumerate(data):
            batch.append(item)
            flush_tail = mode == "gta" and idx == len(data) - 1
            if len(batch) == batch_size or flush_tail:
                names, ps, ds, lengths, wavs, wav_lengths = zip(*batch)
                ps = np.array(ps, dtype=np.int32)
                ds = np.array(ds, dtype=np.float32)
                lengths = np.array(lengths, dtype=np.int32)
                wavs = np.array(wavs, dtype=np.int16)
                wav_lengths = np.array(wav_lengths, dtype=np.int32)
                if mode == "gta":
                    yield names, AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
                else:
                    yield AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
                batch = []
        if mode == "gta":
            # Single pass only; the tail batch was flushed above.
            assert len(batch) == 0
            break
|