update

Thread the checkpoint and config paths through the inference helpers as keyword arguments instead of hard-coding them inside the function bodies.

- vietTTS/hifigan/mel2wave.py: +7 -7
- vietTTS/nat/text2mel.py: +11 -8
vietTTS/hifigan/mel2wave.py

@@ -1,11 +1,9 @@
 import json
-import os
 import pickle
 
 import haiku as hk
 import jax
 import jax.numpy as jnp
-import numpy as np
 
 from .config import FLAGS
 from .model import Generator
@@ -17,9 +15,11 @@ class AttrDict(dict):
         self.__dict__ = self
 
 
-def mel2wave(mel):
-    config_file = "assets/hifigan/config.json"
-    …
+def mel2wave(
+    mel,
+    config_file="assets/hifigan/config.json",
+    ckpt_file=FLAGS.ckpt_dir / "hk_hifi.pickle",
+):
     with open(config_file) as f:
         data = f.read()
     json_config = json.loads(data)
@@ -32,10 +32,10 @@ def mel2wave(mel):
 
     rng = next(hk.PRNGSequence(42))
 
-    with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "rb") as f:
+    with open(ckpt_file, "rb") as f:
         params = pickle.load(f)
     aux = {}
     wav, aux = forward.apply(params, aux, rng, mel)
     wav = jnp.squeeze(wav)
     audio = jax.device_get(wav)
-    return audio
\ No newline at end of file
+    return audio
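For context, a minimal sketch of how the parameterized vocoder entry point can now be called. The explicit checkpoint location and the zero-filled mel with 80 bins are hypothetical stand-ins; the input shape (batch, frames, n_mels) is assumed from the surrounding code rather than stated by this commit:

from pathlib import Path

import numpy as np

from vietTTS.hifigan.mel2wave import mel2wave

# Assumed input: a float32 mel spectrogram with a leading batch axis.
mel = np.zeros((1, 200, 80), dtype=np.float32)  # hypothetical 80-bin mel

# Before this commit both paths were hard-coded inside mel2wave;
# now a fine-tuned checkpoint can be selected per call.
wav = mel2wave(
    mel,
    config_file="assets/hifigan/config.json",
    ckpt_file=Path("my_ckpts") / "hk_hifi.pickle",  # hypothetical location
)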
vietTTS/nat/text2mel.py

@@ -19,12 +19,12 @@ def load_lexicon(fn):
     return dict(lines)
 
 
-def predict_duration(tokens):
+def predict_duration(tokens, ckpt_file):
     def fwd_(x):
         return DurationModel(is_training=False)(x)
 
     forward_fn = jax.jit(hk.transform_with_state(fwd_).apply)
-    with open(FLAGS.ckpt_dir / "duration_latest_ckpt.pickle", "rb") as f:
+    with open(ckpt_file, "rb") as f:
         dic = pickle.load(f)
     x = DurationInput(
         np.array(tokens, dtype=np.int32)[None, :],
@@ -58,8 +58,7 @@ def text2tokens(text, lexicon_fn):
     return tokens
 
 
-def predict_mel(tokens, durations):
-    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
+def predict_mel(tokens, durations, ckpt_fn):
     with open(ckpt_fn, "rb") as f:
         dic = pickle.load(f)
     last_step, params, aux, rng, optim_state = (
@@ -83,10 +82,14 @@ def predict_mel(tokens, durations):
 
 
 def text2mel(
-    text: str, lexicon_fn=FLAGS.data_dir / "lexicon.txt", silence_duration: float = -1.0
+    text: str,
+    lexicon_fn=FLAGS.data_dir / "lexicon.txt",
+    silence_duration: float = -1.0,
+    acoustic_ckpt=FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle",
+    duration_ckpt=FLAGS.ckpt_dir / "duration_latest_ckpt.pickle",
 ):
     tokens = text2tokens(text, lexicon_fn)
-    durations = predict_duration(tokens)
+    durations = predict_duration(tokens, duration_ckpt)
     durations = jnp.where(
         np.array(tokens)[None, :] == FLAGS.sil_index,
         jnp.clip(durations, a_min=silence_duration, a_max=None),
@@ -95,7 +98,7 @@ def text2mel(
     durations = jnp.where(
         np.array(tokens)[None, :] == FLAGS.word_end_index, 0.0, durations
     )
-    mels = predict_mel(tokens, durations)
+    mels = predict_mel(tokens, durations, acoustic_ckpt)
     if tokens[-1] == FLAGS.sil_index:
         end_silence = durations[0, -1].item()
         silence_frame = int(end_silence * FLAGS.sample_rate / (FLAGS.n_fft // 4))
@@ -114,4 +117,4 @@ if __name__ == "__main__":
     plt.savefig(str(args.output))
     plt.close()
     mel = jax.device_get(mel)
-    mel.tofile("clip.mel")
\ No newline at end of file
+    mel.tofile("clip.mel")
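Taken together, the two changes let a caller thread every artifact path through one call chain. A hedged end-to-end sketch, assuming the default FLAGS locations, a 16 kHz sample rate, a scipy dependency for writing the WAV, and the 16-bit PCM scaling, none of which are asserted by this commit:

import numpy as np
from scipy.io import wavfile

from vietTTS.hifigan.mel2wave import mel2wave
from vietTTS.nat.text2mel import text2mel

# Synthesize a mel spectrogram; the checkpoint and lexicon paths fall back
# to the new keyword defaults under FLAGS.data_dir / FLAGS.ckpt_dir.
mel = text2mel("xin chào các bạn", silence_duration=0.2)  # "hello everyone"

# Vocode the mel to a float waveform with the HiFi-GAN generator.
wav = mel2wave(mel)

# Assumed post-processing: scale float audio into 16-bit PCM and save.
pcm = (np.asarray(wav) * (2**15 - 1)).astype(np.int16)
wavfile.write("clip.wav", 16000, pcm)  # 16 kHz assumed, cf. FLAGS.sample_rate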