import pickle
from argparse import ArgumentParser
from pathlib import Path

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from .config import FLAGS, DurationInput
from .data_loader import load_phonemes_set
from .model import AcousticModel, DurationModel


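def load_lexicon(fn):
    """Load a tab-separated lexicon: one `word<TAB>phonemes` entry per line."""
    with open(fn, "r", encoding="utf-8") as f:
        lines = [l.lower().strip().split("\t") for l in f]
    # Skip blank or malformed lines, which would otherwise make dict() fail.
    return dict(l for l in lines if len(l) == 2)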


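def predict_duration(tokens):
    """Predict per-token durations (in seconds) with the duration model."""

    def fwd_(x):
        return DurationModel(is_training=False)(x)

    forward_fn = jax.jit(hk.transform_with_state(fwd_).apply)
    with open(FLAGS.ckpt_dir / "duration_latest_ckpt.pickle", "rb") as f:
        dic = pickle.load(f)
    # Batch of one: phoneme ids [1, L] and lengths [1]; the last field is None
    # at inference.
    x = DurationInput(
        np.array(tokens, dtype=np.int32)[None, :],
        np.array([len(tokens)], dtype=np.int32),
        None,
    )
    # apply() returns (output, new_state); keep only the output.
    return forward_fn(dic["params"], dic["aux"], dic["rng"], x)[0]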


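def text2tokens(text, lexicon_fn):
    """Convert text to phoneme ids, framed by leading and trailing silence.

    Special phonemes map directly to their id. Lexicon words become their
    phonemes followed by a word-end token. Any other word falls back to those
    of its characters that are known phonemes, followed by a word-end token.
    """
    phonemes = load_phonemes_set()
    lexicon = load_lexicon(lexicon_fn)

    words = text.strip().lower().split()
    tokens = [FLAGS.sil_index]
    for word in words:
        if word in FLAGS.special_phonemes:
            tokens.append(phonemes.index(word))
        elif word in lexicon:
            tokens.extend(phonemes.index(p) for p in lexicon[word].split())
            tokens.append(FLAGS.word_end_index)
        else:
            # Out-of-lexicon word: keep only characters that are known phonemes.
            tokens.extend(phonemes.index(c) for c in word if c in phonemes)
            tokens.append(FLAGS.word_end_index)
    tokens.append(FLAGS.sil_index)  # trailing silence
    return tokens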
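

# Sketch (hypothetical helper, not used by the pipeline): the same tokenization
# rules as text2tokens, with the phoneme list, lexicon and special indices passed
# in explicitly instead of read from FLAGS and files. This lets the token layout
# be checked on toy inputs. All values in the example below are made up.
def toy_text2tokens(text, phonemes, lexicon, sil_index, word_end_index, special=()):
    tokens = [sil_index]  # leading silence
    for word in text.strip().lower().split():
        if word in special:
            tokens.append(phonemes.index(word))  # no word-end marker
        elif word in lexicon:
            tokens.extend(phonemes.index(p) for p in lexicon[word].split())
            tokens.append(word_end_index)
        else:
            # fall back to characters that are themselves phonemes
            tokens.extend(phonemes.index(c) for c in word if c in phonemes)
            tokens.append(word_end_index)
    tokens.append(sil_index)  # trailing silence
    return tokens


# Example with made-up values: phonemes = ["sil", "sp", "a", "b", "x"],
# lexicon = {"ab": "a b"}, sil_index = 0, word_end_index = 1.
# toy_text2tokens("ab xa", phonemes, lexicon, 0, 1) gives [0, 2, 3, 1, 4, 2, 1, 0]:
# "ab" comes from the lexicon, and "xa" falls back to its characters.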


def predict_mel(tokens, durations):
    """Generate a mel-spectrogram from phoneme ids and durations (in seconds)."""
    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    with open(ckpt_fn, "rb") as f:
        dic = pickle.load(f)
    # Inference needs only the parameters, state and RNG key.
    params, aux, rng = dic["params"], dic["aux"], dic["rng"]

    @hk.transform_with_state
    def forward(tokens, durations, n_frames):
        net = AcousticModel(is_training=False)
        return net.inference(tokens, durations, n_frames)

    # Convert seconds to frames, using a hop length of n_fft // 4 samples.
    durations = durations * FLAGS.sample_rate / (FLAGS.n_fft // 4)
    n_frames = int(jnp.sum(durations).item())
    # n_frames (argument 5 of apply) sets the output shape, so it must be static.
    predict_fn = jax.jit(forward.apply, static_argnums=[5])
    tokens = np.array(tokens, dtype=np.int32)[None, :]
    return predict_fn(params, aux, rng, tokens, durations, n_frames)[0]
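

# Sketch (hypothetical helper, not called above): the seconds-to-frames
# conversion used in predict_mel and text2mel. A duration of d seconds spans
# d * sample_rate samples, which is d * sample_rate / hop_length frames. The
# code above takes hop_length = n_fft // 4.
def seconds_to_frames(seconds, sample_rate, n_fft):
    hop_length = n_fft // 4
    return seconds * sample_rate / hop_length


# Example with assumed settings sample_rate=16000 and n_fft=1024 (hop 256):
# seconds_to_frames(0.5, 16000, 1024) == 31.25 frames. predict_mel truncates
# the summed frame count with int() to get n_frames.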


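def text2mel(
    text: str, lexicon_fn=FLAGS.data_dir / "lexicon.txt", silence_duration: float = -1.0
):
    """Synthesize a mel-spectrogram for `text`.

    Silence tokens are given at least `silence_duration` seconds. A negative
    value, such as the default, has no effect on non-negative predictions.
    Word-end tokens get zero duration, and the frames of the trailing silence
    are trimmed from the output.
    """
    tokens = text2tokens(text, lexicon_fn)
    durations = predict_duration(tokens)
    token_ids = np.array(tokens)[None, :]
    # Enforce a minimum duration on silence tokens.
    durations = jnp.where(
        token_ids == FLAGS.sil_index,
        jnp.maximum(durations, silence_duration),
        durations,
    )
    # Word-end tokens get zero duration, so they produce no frames.
    durations = jnp.where(token_ids == FLAGS.word_end_index, 0.0, durations)
    mels = predict_mel(tokens, durations)
    if tokens[-1] == FLAGS.sil_index:
        # Trim the frames generated for the trailing silence token.
        end_silence = durations[0, -1].item()
        silence_frame = int(end_silence * FLAGS.sample_rate / (FLAGS.n_fft // 4))
        mels = mels[:, : (mels.shape[1] - silence_frame)]
    return mels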
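

# Sketch (hypothetical helper, not used by text2mel): the two duration
# adjustments in text2mel, written over plain Python lists for one utterance.
# Word-end tokens are zeroed (checked first, since text2mel applies that mask
# last), silence tokens are floored at silence_duration, and all other
# durations pass through unchanged.
def adjust_durations(tokens, durations, sil_index, word_end_index, silence_duration):
    adjusted = []
    for token, d in zip(tokens, durations):
        if token == word_end_index:
            adjusted.append(0.0)
        elif token == sil_index:
            adjusted.append(max(d, silence_duration))
        else:
            adjusted.append(d)
    return adjusted


# Example with made-up values (sil_index=0, word_end_index=1, floor 0.1 s):
# adjust_durations([0, 2, 1, 0], [0.05, 0.1, 0.03, 0.2], 0, 1, 0.1)
# gives [0.1, 0.1, 0.0, 0.2].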


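if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--text", type=str, required=True)
    parser.add_argument("--output", type=Path, required=True)
    args = parser.parse_args()
    # Copy the result to host memory as a NumPy array.
    mel = jax.device_get(text2mel(args.text))
    # Save a plot of the mel-spectrogram (time on the x-axis) to --output.
    plt.figure(figsize=(10, 5))
    plt.imshow(mel[0].T, origin="lower", aspect="auto")
    plt.savefig(str(args.output))
    plt.close()
    # Also write the raw array (no shape or dtype header) to clip.mel in the
    # working directory.
    mel.tofile("clip.mel")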