File size: 5,021 Bytes
12da6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import random
from pathlib import Path

import numpy as np
import textgrid
from scipy.io import wavfile

from .config import FLAGS, AcousticInput, DurationInput


def load_phonemes_set():
    """Return the full phoneme inventory.

    The result is ``FLAGS.special_phonemes`` concatenated with
    ``FLAGS._normal_phonemes``, so special phonemes occupy the lowest indices.
    """
    return FLAGS.special_phonemes + FLAGS._normal_phonemes


def pad_seq(s, maxlen, value=0):
    """Right-pad sequence ``s`` with ``value`` up to ``maxlen`` items.

    Returns a new tuple of exactly ``maxlen`` elements. ``s`` must not be
    longer than ``maxlen``; this is enforced with an assertion.
    """
    assert len(s) <= maxlen
    num_missing = maxlen - len(s)
    padding = (value,) * num_missing
    return tuple(s) + padding


def is_in_word(phone, word):
    """Tell whether the ``phone`` interval lies inside the ``word`` interval.

    Both endpoints of ``phone`` (``minTime`` and ``maxTime``) must fall
    strictly within ``word``'s span widened by 1 ms on each side, which
    absorbs small boundary mismatches between the two intervals.
    """
    tolerance = 1e-3
    lower = word.minTime - tolerance
    upper = word.maxTime + tolerance
    if not lower < phone.minTime < upper:
        return False
    return lower < phone.maxTime < upper


def load_textgrid(fn: Path):
    """Read a TextGrid file and return its phoneme/duration sequence.

    Tier 0 is treated as the word tier and tier 1 as the phoneme tier.
    The result is a list of ``(mark, duration)`` tuples in phoneme order:

    * marks are stripped and lower-cased; an empty mark becomes ``"sil"``;
    * whenever the phonemes move past a word whose mark is non-empty, a
      word-end special phoneme with duration ``0.0`` is inserted;
    * reading stops as soon as a phoneme falls beyond the last word.

    Asserts that the first phoneme starts at time 0 and that each phoneme
    belongs either to the current word or to the one right after it.
    """
    grid = textgrid.TextGrid.fromFile(str(fn.resolve()))
    word_tier = list(grid[0])
    phone_tier = grid[1]
    assert phone_tier[0].minTime == 0, "The first phoneme has to start at time 0"

    entries = []
    word_pos = 0
    for phone in phone_tier:
        if phone not in word_tier[word_pos]:
            # The phoneme left the current word: close it and advance.
            finished_word = word_tier[word_pos]
            word_pos += 1
            if len(finished_word.mark) > 0:
                word_end = FLAGS.special_phonemes[FLAGS.word_end_index]
                entries.append((word_end, 0.0))
            if word_pos >= len(word_tier):
                break
            assert phone in word_tier[word_pos], "mismatched word vs phoneme"
        label = phone.mark.strip().lower() or "sil"
        entries.append((label, phone.duration()))
    return entries


def textgrid_data_loader(data_dir: Path, seq_len: int, batch_size: int, mode: str):
    """Yield an endless stream of ``DurationInput`` batches from TextGrid files.

    All ``*.TextGrid`` files in ``data_dir`` are sorted, shuffled with a fixed
    seed (42) and split 95% / 5%: ``mode="train"`` uses the first part,
    ``mode="val"`` the remainder. Every file is loaded into memory up front.

    Args:
        data_dir: directory containing the ``*.TextGrid`` files.
        seq_len: length every phoneme/duration sequence is padded to (with 0).
        batch_size: number of utterances per yielded batch.
        mode: ``"train"`` or ``"val"``.

    Yields:
        ``DurationInput(ps, lengths, ds)`` where ``ps`` is an int32 array of
        phoneme indices with shape ``(batch_size, seq_len)``, ``lengths`` is an
        int32 array of unpadded sequence lengths and ``ds`` is a float32 array
        of durations with the same shape as ``ps``.

    Raises:
        ValueError: if the selected split contains no files (raised on the
            first ``next()``; previously the generator would spin forever
            without yielding). Also raised by ``list.index`` when a phoneme is
            not part of the phoneme set.
        AssertionError: if ``mode`` is invalid or an utterance is longer than
            ``seq_len``.
    """
    tg_files = sorted(data_dir.glob("*.TextGrid"))
    # Fixed seed keeps the train/val split identical across runs and modes.
    random.Random(42).shuffle(tg_files)
    L = len(tg_files) * 95 // 100
    assert mode in ["train", "val"]
    phonemes = load_phonemes_set()
    if mode == "train":
        tg_files = tg_files[:L]
    elif mode == "val":
        tg_files = tg_files[L:]

    data = []
    for fn in tg_files:
        ps, ds = zip(*load_textgrid(fn))
        ps = [phonemes.index(p) for p in ps]
        length = len(ps)  # true length, recorded before padding
        ps = pad_seq(ps, seq_len, 0)
        ds = pad_seq(ds, seq_len, 0)
        data.append((ps, ds, length))

    if not data:
        # Without this guard the loop below never yields and never terminates.
        raise ValueError(
            f"no TextGrid files available for mode {mode!r} in {data_dir}"
        )

    batch = []
    while True:
        random.shuffle(data)
        for e in data:
            batch.append(e)
            if len(batch) == batch_size:
                ps, ds, lengths = zip(*batch)
                ps = np.array(ps, dtype=np.int32)
                ds = np.array(ds, dtype=np.float32)
                lengths = np.array(lengths, dtype=np.int32)
                yield DurationInput(ps, lengths, ds)
                # A partially filled batch carries over into the next epoch.
                batch = []
                batch = []


def load_textgrid_wav(
    data_dir: Path, token_seq_len: int, batch_size, pad_wav_len, mode: str
):
    """Load wav and TextGrid files to memory and yield ``AcousticInput`` batches.

    All ``*.TextGrid`` files in ``data_dir`` are sorted, shuffled with a fixed
    seed (42) and split 95% / 5%. ``mode="train"`` uses the first part,
    ``mode="val"`` the remainder and ``mode="gta"`` every file. Each TextGrid
    is paired with ``<stem>.wav`` from the same directory.

    For each utterance, audio samples that fall inside a special phoneme
    (index below ``len(FLAGS.special_phonemes)``) are set to zero, and
    the waveform is truncated or zero-padded to ``pad_wav_len`` samples.

    Args:
        data_dir: directory containing the ``*.TextGrid`` and ``*.wav`` files.
        token_seq_len: length phoneme/duration sequences are padded to (with 0).
        batch_size: number of utterances per batch.
        pad_wav_len: number of samples every waveform is cut/padded to.
        mode: ``"train"``, ``"val"`` or ``"gta"``.

    Yields:
        In ``"train"``/``"val"`` mode, endlessly:
        ``AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)`` with int32
        phoneme indices, int32 unpadded phoneme counts, float32 durations,
        int16 waveforms and int32 waveform lengths (after truncation, before
        padding). In ``"gta"`` mode a single pass is made, each item is
        ``(names, AcousticInput(...))`` with ``names`` the file stems, and the
        final batch may be smaller than ``batch_size``.

    Raises:
        ValueError: if the ``"train"``/``"val"`` split contains no files
            (previously the generator would spin forever without yielding).
            Also raised by ``list.index`` for an unknown phoneme.
        AssertionError: if ``mode`` is invalid or an utterance has more than
            ``token_seq_len`` phonemes.
    """
    tg_files = sorted(data_dir.glob("*.TextGrid"))
    # Fixed seed keeps the train/val split identical across runs and modes.
    random.Random(42).shuffle(tg_files)
    L = len(tg_files) * 95 // 100
    assert mode in ["train", "val", "gta"]
    phonemes = load_phonemes_set()
    if mode == "train":
        tg_files = tg_files[:L]
    elif mode == "val":
        tg_files = tg_files[L:]
    # mode == "gta": keep all files

    num_special = len(FLAGS.special_phonemes)
    data = []
    for fn in tg_files:
        ps, ds = zip(*load_textgrid(fn))
        ps = [phonemes.index(p) for p in ps]
        # Keep the phoneme count in its own variable: the sample offsets
        # computed in the muting loop below must not overwrite it.
        num_tokens = len(ps)
        ps = pad_seq(ps, token_seq_len, 0)
        ds = pad_seq(ds, token_seq_len, 0)

        wav_file = data_dir / f"{fn.stem}.wav"
        sr, y = wavfile.read(wav_file)
        y = np.copy(y)  # make sure the samples are writable before muting
        start_time = 0
        last_index = len(ps) - 1
        for i, (phone_idx, duration) in enumerate(zip(ps, ds)):
            left = int(start_time * sr)
            end_time = start_time + duration
            # The last (possibly padded) token extends to the end of the audio.
            right = len(y) if i == last_index else int(end_time * sr)
            if phone_idx < num_special:
                y[left:right] = 0
            start_time = end_time

        if len(y) > pad_wav_len:
            y = y[:pad_wav_len]

        # # normalize to match hifigan preprocessing
        # y = y.astype(np.float32)
        # y = y / np.max(np.abs(y))
        # y = y * 0.95
        # y = y * (2 ** 15)
        # y = y.astype(np.int16)

        wav_length = len(y)
        y = np.pad(y, (0, pad_wav_len - len(y)))
        data.append((fn.stem, ps, ds, num_tokens, y, wav_length))

    if not data and mode != "gta":
        # Without this guard the endless loop below never yields.
        raise ValueError(
            f"no TextGrid files available for mode {mode!r} in {data_dir}"
        )

    batch = []
    while True:
        random.shuffle(data)
        for idx, e in enumerate(data):
            batch.append(e)
            # In gta mode the trailing partial batch is flushed as well.
            if len(batch) == batch_size or (mode == "gta" and idx == len(data) - 1):
                names, ps, ds, lengths, wavs, wav_lengths = zip(*batch)
                ps = np.array(ps, dtype=np.int32)
                ds = np.array(ds, dtype=np.float32)
                lengths = np.array(lengths, dtype=np.int32)
                wavs = np.array(wavs, dtype=np.int16)
                wav_lengths = np.array(wav_lengths, dtype=np.int32)
                if mode == "gta":
                    yield names, AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
                else:
                    yield AcousticInput(ps, lengths, ds, wavs, wav_lengths, None)
                batch = []
        if mode == "gta":
            # gta makes exactly one pass over the data.
            assert len(batch) == 0
            break