|
import json |
|
import pickle |
|
|
|
import haiku as hk |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
from .config import FLAGS |
|
from .model import Generator |
|
|
|
|
|
class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    The instance's attribute namespace is the mapping itself, so ``d.key``
    and ``d["key"]`` always refer to the same entry. Assigning through either
    form is visible through the other.
    """

    def __init__(self, *args, **kwargs):
        # Fill the mapping exactly as ``dict(*args, **kwargs)`` would.
        super().__init__(*args, **kwargs)
        # Point the attribute namespace at the mapping itself.
        # NOTE: this makes the object reference itself (self -> __dict__ -> self),
        # so instances are reclaimed by the cyclic GC rather than by refcounting.
        self.__dict__ = self
|
|
|
|
|
def mel2wave(
    mel,
    config_file=None,
    ckpt_file=None,
):
    """Run the Haiku ``Generator`` on ``mel`` and return the result on the host.

    Args:
        mel: Passed unchanged as the sole input to ``Generator(h)``.
            NOTE(review): presumably a (batch, n_mels, frames) mel spectrogram;
            confirm against ``.model.Generator``.
        config_file: Path to the JSON hyper-parameter file used to build the
            generator. Defaults to ``FLAGS.ckpt_dir / "config.json"``. The
            default is resolved when the function is called.
        ckpt_file: Path to a pickled params tree for ``forward.apply``.
            Defaults to ``FLAGS.ckpt_dir / "hk_hifi.pickle"``. The default is
            resolved when the function is called.

    Returns:
        The generator output with all size-1 axes removed (``jnp.squeeze``),
        transferred to the host with ``jax.device_get``.
    """
    # Resolve defaults at call time. Default expressions in the signature are
    # evaluated once at import, which would freeze FLAGS.ckpt_dir's value at
    # that moment (or fail if FLAGS is not ready yet).
    if config_file is None:
        config_file = FLAGS.ckpt_dir / "config.json"
    if ckpt_file is None:
        ckpt_file = FLAGS.ckpt_dir / "hk_hifi.pickle"

    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(config_file, encoding="utf-8") as f:
        h = AttrDict(json.load(f))

    @hk.transform_with_state
    def forward(x):
        net = Generator(h)
        return net(x)

    # Fixed seed so repeated calls are deterministic.
    rng = next(hk.PRNGSequence(42))

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints from trusted sources.
    with open(ckpt_file, "rb") as f:
        params = pickle.load(f)

    # An empty state dict is passed. This assumes Generator keeps no Haiku
    # state (hk.get_state / BatchNorm statistics); TODO confirm.
    state = {}
    wav, _ = forward.apply(params, state, rng, mel)
    wav = jnp.squeeze(wav)

    # TODO: debug output; consider switching to logging.
    print("wav : ")
    print(wav)

    # NOTE(review): this changes a process-wide JAX setting *after* inference
    # has already been dispatched, so it does not affect the computation above.
    # jax.device_get below copies to the host regardless. Kept only because
    # later JAX code in this process may depend on it; confirm intent.
    jax.config.update('jax_platform_name', 'cpu')

    audio = jax.device_get(wav)

    print("audio : ")
    print(audio)

    return audio