File size: 370 Bytes
12da6cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import haiku
import haiku as hk
import jax.numpy as jnp
import jax.random
from vietTTS.nat.config import FLAGS
from vietTTS.nat.model import DurationModel
@hk.testing.transform_and_run
def test_duration():
net = DurationModel()
p = jnp.zeros((2, 10), dtype=jnp.int32)
l = jnp.zeros((2,), dtype=jnp.int32)
o = net(p, l)
assert o.shape == (2, 10, 1)
|