import haiku
import haiku as hk
import jax.numpy as jnp
import jax.random
from vietTTS.nat.config import FLAGS
from vietTTS.nat.model import AcousticModel

# NOTE(review): `haiku` is imported twice (plain and as `hk`), and `jax.random`
# and `FLAGS` are not used by the test below; left in place in case other code
# depends on them.


@hk.testing.transform_and_run
def test_duration():
    """Smoke-test a forward pass of ``AcousticModel`` on all-zero inputs.

    Builds a (2, 10) int32 token array, a (2,) int32 lengths array, a (2, 10)
    float32 durations array and a (2, 20, 160) float32 mel array, calls the
    model with them, and checks that both returned arrays have shape
    (2, 20, 160).
    """
    # NOTE(review): despite its name, this exercises AcousticModel rather than a
    # duration model, and every entry of `seq_lengths` is 0 — confirm both are
    # intentional. The 160 below may be meant to come from the config (FLAGS);
    # TODO confirm.
    batch, num_tokens, num_frames, mel_channels = 2, 10, 20, 160

    token_ids = jnp.zeros((batch, num_tokens), dtype=jnp.int32)
    seq_lengths = jnp.zeros((batch,), dtype=jnp.int32)
    token_durations = jnp.zeros((batch, num_tokens), dtype=jnp.float32)
    mel_input = jnp.zeros(
        (batch, num_frames, mel_channels), dtype=jnp.float32
    )

    # Argument order matches the model call this test was written against:
    # (tokens, mel, lengths, durations).
    model = AcousticModel()
    out_first, out_second = model(
        token_ids, mel_input, seq_lengths, token_durations
    )

    expected_shape = (batch, num_frames, mel_channels)
    assert out_first.shape == expected_shape
    assert out_second.shape == expected_shape