# File: tts/tests/test_nat_acoustic.py
# Source: uploaded by tobiccino (commit message "upload", commit 12da6cc); 557 bytes.
import haiku
import haiku as hk
import jax.numpy as jnp
import jax.random
from vietTTS.nat.config import FLAGS
from vietTTS.nat.model import AcousticModel
@hk.testing.transform_and_run
def test_duration():
    """Smoke-test ``AcousticModel`` output shapes on all-zero inputs.

    Builds a batch of 2 sequences of 10 int32 tokens, matching int32
    lengths and float32 durations, and an all-zero float32 mel input of
    shape (2, 20, 160). Runs the model once and checks that it returns
    exactly two arrays, each shaped like the mel input.

    NOTE(review): despite its name, this test exercises ``AcousticModel``,
    not a duration model. ``lengths`` and ``durations`` are all zero, so
    only output shapes are checked. Confirm the model is meant to accept
    such degenerate inputs.
    """
    batch, num_tokens, num_frames, num_mels = 2, 10, 20, 160
    mel_shape = (batch, num_frames, num_mels)

    model = AcousticModel()
    tokens = jnp.zeros((batch, num_tokens), dtype=jnp.int32)
    token_lengths = jnp.zeros((batch,), dtype=jnp.int32)
    token_durations = jnp.zeros((batch, num_tokens), dtype=jnp.float32)
    mel_input = jnp.zeros(mel_shape, dtype=jnp.float32)

    # Positional order matters: (tokens, mel, lengths, durations).
    first_out, second_out = model(tokens, mel_input, token_lengths, token_durations)
    for out in (first_out, second_out):
        assert out.shape == mel_shape