tts / vietTTS /hifigan /mel2wave.py
tobiccino's picture
add stop duration option
320b21c
raw
history blame
930 Bytes
import json
import pickle
import haiku as hk
import jax
import jax.numpy as jnp
from .config import FLAGS
from .model import Generator
class AttrDict(dict):
    """A ``dict`` whose entries can also be accessed as attributes.

    The instance's attribute namespace *is* the mapping itself, so
    ``d.key`` and ``d["key"]`` read and write the same entry.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the dict. This makes the object
        # reference itself, so it is reclaimed by the cycle collector rather
        # than by plain reference counting.
        self.__dict__ = self
def mel2wave(
    mel,
    config_file=None,
    ckpt_file=None,
):
    """Convert a mel-spectrogram into a waveform using the HiFi-GAN generator.

    The generator hyper-parameters are read from a JSON config file and the
    Haiku parameters from a pickled checkpoint; both are loaded on every call.

    Args:
        mel: Input passed directly to the generator's forward function.
            NOTE(review): presumably a batched mel-spectrogram array —
            confirm the expected shape against ``Generator``.
        config_file: Path to the generator's JSON config. When omitted,
            ``FLAGS.ckpt_dir / "config.json"`` is used, resolved at call time.
        ckpt_file: Path to the pickled Haiku parameters. When omitted,
            ``FLAGS.ckpt_dir / "hk_hifi.pickle"`` is used, resolved at call
            time.

    Returns:
        The generator output with all size-1 axes squeezed out, fetched
        via ``jax.device_get``.
    """
    # Resolve defaults when called instead of at import time, so a change to
    # FLAGS.ckpt_dir made after this module was imported is honoured.
    if config_file is None:
        config_file = FLAGS.ckpt_dir / "config.json"
    if ckpt_file is None:
        ckpt_file = FLAGS.ckpt_dir / "hk_hifi.pickle"

    with open(config_file, encoding="utf-8") as f:
        h = AttrDict(json.load(f))

    @hk.transform_with_state
    def forward(x):
        net = Generator(h)
        return net(x)

    rng = next(hk.PRNGSequence(42))

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load checkpoints that come from a trusted source.
    with open(ckpt_file, "rb") as f:
        params = pickle.load(f)

    # The generator starts from an empty state; the returned state is unused.
    wav, _ = forward.apply(params, {}, rng, mel)
    wav = jnp.squeeze(wav)
    # NOTE(review): this mutates global JAX configuration after the generator
    # has already been applied, so it does not steer this call's computation;
    # it may affect later JAX work in the process. Kept to preserve existing
    # behaviour — confirm whether it is still needed.
    jax.config.update('jax_platform_name', 'cpu')
    return jax.device_get(wav)