|
import json |
|
import os |
|
import pickle |
|
|
|
import haiku as hk |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
from .config import FLAGS |
|
from .model import Generator |
|
|
|
|
|
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    The instance's ``__dict__`` is bound to the mapping itself, so
    ``d.key`` and ``d["key"]`` always refer to the same stored value.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping exactly as ``dict(*args, **kwargs)`` would."""
        super().__init__(*args, **kwargs)
        # Share one storage for items and attributes: attribute lookups and
        # assignments now go straight to the dict's own entries.
        self.__dict__ = self
|
|
|
|
|
def mel2wave(mel):
    """Run the HiFi-GAN ``Generator`` on ``mel`` and return the result on host.

    The generator configuration is read from ``assets/hifigan/config.json``
    (a relative path, so it is resolved against the current working
    directory) and the parameters are unpickled from
    ``FLAGS.ckpt_dir / "hk_hifi.pickle"``. Both files are re-read on every
    call.

    Args:
        mel: Input handed unchanged to ``Generator(h)``. Presumably a
            mel-spectrogram array -- TODO confirm the expected shape and
            dtype against ``Generator``.

    Returns:
        The generator output with all size-1 axes removed
        (``jnp.squeeze``), transferred to host memory by
        ``jax.device_get``.

    Raises:
        OSError: If the config or checkpoint file cannot be opened.
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    config_file = "assets/hifigan/config.json"

    # JSON is UTF-8 by specification; pin the encoding instead of relying on
    # the platform/locale default.
    with open(config_file, encoding="utf-8") as f:
        h = AttrDict(json.load(f))

    @hk.transform_with_state
    def forward(x):
        net = Generator(h)
        return net(x)

    # Fixed seed: the key handed to ``apply`` is the same on every call.
    rng = next(hk.PRNGSequence(42))

    # SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    # file. Only point ``FLAGS.ckpt_dir`` at checkpoints from a trusted source.
    with open(FLAGS.ckpt_dir / "hk_hifi.pickle", "rb") as f:
        params = pickle.load(f)

    # The network is applied with an empty initial state; the state returned
    # by ``apply`` is not needed afterwards and is discarded.
    state = {}
    wav, _ = forward.apply(params, state, rng, mel)

    wav = jnp.squeeze(wav)
    return jax.device_get(wav)
|
|