import haiku | |
import haiku as hk | |
import jax.numpy as jnp | |
import jax.random | |
from vietTTS.nat.config import FLAGS | |
from vietTTS.nat.model import DurationModel | |
def test_duration(): | |
net = DurationModel() | |
p = jnp.zeros((2, 10), dtype=jnp.int32) | |
l = jnp.zeros((2,), dtype=jnp.int32) | |
o = net(p, l) | |
assert o.shape == (2, 10, 1) | |