tobiccino committed
Commit 725d577
1 Parent(s): 5386471
vietTTS/hifigan/mel2wave.py CHANGED
@@ -1,11 +1,9 @@
 import json
-import os
 import pickle
 
 import haiku as hk
 import jax
 import jax.numpy as jnp
-import numpy as np
 
 from .config import FLAGS
 from .model import Generator
@@ -17,9 +15,11 @@ class AttrDict(dict):
         self.__dict__ = self
 
 
-def mel2wave(mel):
-    config_file = "assets/hifigan/config.json"
-    MAX_WAV_VALUE = 32768.0
+def mel2wave(
+    mel,
+    config_file="assets/hifigan/config.json",
+    ckpt_file=FLAGS.ckpt_dir / "hk_hifi.pickle",
+):
     with open(config_file) as f:
         data = f.read()
         json_config = json.loads(data)
@@ -32,10 +32,10 @@ def mel2wave(mel):
 
     rng = next(hk.PRNGSequence(42))
 
-    with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "rb") as f:
+    with open(ckpt_file, "rb") as f:
         params = pickle.load(f)
     aux = {}
     wav, aux = forward.apply(params, aux, rng, mel)
     wav = jnp.squeeze(wav)
     audio = jax.device_get(wav)
-    return audio
+    return audio
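The refactor turns mel2wave's hard-coded config and checkpoint locations into keyword arguments whose defaults reproduce the old values, and drops the unused os/numpy imports along with the dead MAX_WAV_VALUE constant. A minimal usage sketch of the new signature; the mel shape and the custom checkpoint path are illustrative assumptions, not part of the commit:

import jax.numpy as jnp

from vietTTS.hifigan.mel2wave import mel2wave

# Hypothetical input: batch of 1, 500 frames, 80 mel bands (shape is an assumption).
mel = jnp.zeros((1, 500, 80))

# The defaults reproduce the previous hard-coded behaviour.
audio = mel2wave(mel)

# A fine-tuned vocoder can now be selected per call instead of by editing the module.
audio = mel2wave(
    mel,
    config_file="assets/hifigan/config.json",
    ckpt_file="checkpoints/my_voice/hk_hifi.pickle",  # hypothetical path
)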
vietTTS/nat/text2mel.py CHANGED
@@ -19,12 +19,12 @@ def load_lexicon(fn):
     return dict(lines)
 
 
-def predict_duration(tokens):
+def predict_duration(tokens, ckpt_file):
     def fwd_(x):
         return DurationModel(is_training=False)(x)
 
     forward_fn = jax.jit(hk.transform_with_state(fwd_).apply)
-    with open(FLAGS.ckpt_dir / "duration_latest_ckpt.pickle", "rb") as f:
+    with open(ckpt_file, "rb") as f:
         dic = pickle.load(f)
     x = DurationInput(
         np.array(tokens, dtype=np.int32)[None, :],
@@ -58,8 +58,7 @@ def text2tokens(text, lexicon_fn):
     return tokens
 
 
-def predict_mel(tokens, durations):
-    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
+def predict_mel(tokens, durations, ckpt_fn):
     with open(ckpt_fn, "rb") as f:
         dic = pickle.load(f)
     last_step, params, aux, rng, optim_state = (
@@ -83,10 +82,14 @@ def predict_mel(tokens, durations):
 
 
 def text2mel(
-    text: str, lexicon_fn=FLAGS.data_dir / "lexicon.txt", silence_duration: float = -1.0
+    text: str,
+    lexicon_fn=FLAGS.data_dir / "lexicon.txt",
+    silence_duration: float = -1.0,
+    acoustic_ckpt=FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle",
+    duration_ckpt=FLAGS.ckpt_dir / "duration_latest_ckpt.pickle",
 ):
     tokens = text2tokens(text, lexicon_fn)
-    durations = predict_duration(tokens)
+    durations = predict_duration(tokens, duration_ckpt)
     durations = jnp.where(
         np.array(tokens)[None, :] == FLAGS.sil_index,
         jnp.clip(durations, a_min=silence_duration, a_max=None),
@@ -95,7 +98,7 @@ def text2mel(
     durations = jnp.where(
         np.array(tokens)[None, :] == FLAGS.word_end_index, 0.0, durations
     )
-    mels = predict_mel(tokens, durations)
+    mels = predict_mel(tokens, durations, acoustic_ckpt)
     if tokens[-1] == FLAGS.sil_index:
         end_silence = durations[0, -1].item()
         silence_frame = int(end_silence * FLAGS.sample_rate / (FLAGS.n_fft // 4))
@@ -114,4 +117,4 @@ if __name__ == "__main__":
     plt.savefig(str(args.output))
     plt.close()
     mel = jax.device_get(mel)
-    mel.tofile("clip.mel")
+    mel.tofile("clip.mel")
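The same parameterization is threaded through the NAT front end: predict_duration and predict_mel now receive their checkpoint paths from text2mel, whose new duration_ckpt and acoustic_ckpt keywords default to the old FLAGS.ckpt_dir locations. An end-to-end sketch combining both changed files, assuming a per-voice checkpoint directory whose file names mirror the new defaults; the directory, phrase, and silence value are illustrative:

from pathlib import Path

from vietTTS.hifigan.mel2wave import mel2wave
from vietTTS.nat.text2mel import text2mel

# Hypothetical fine-tuned voice; file names mirror the new default arguments.
voice_dir = Path("checkpoints/my_voice")

mel = text2mel(
    "xin chào",  # the phrase must be covered by the lexicon
    silence_duration=0.2,  # floor for silence-token durations (assumed seconds)
    duration_ckpt=voice_dir / "duration_latest_ckpt.pickle",
    acoustic_ckpt=voice_dir / "acoustic_latest_ckpt.pickle",
)
audio = mel2wave(mel, ckpt_file=voice_dir / "hk_hifi.pickle")

Because every path is now an argument, one process can synthesize with several voices without mutating FLAGS or the module source between calls.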