File size: 370 Bytes
12da6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import haiku
import haiku as hk
import jax.numpy as jnp
import jax.random
from vietTTS.nat.config import FLAGS
from vietTTS.nat.model import DurationModel


@hk.testing.transform_and_run
def test_duration():
    """Check the output shape of ``DurationModel`` on an all-zero batch.

    The model is called on a (2, 10) int32 input that is all zeros, plus a
    (2,) int32 vector that is also all zeros. It must return an array of
    shape (batch, seq_len, 1), i.e. one value per input position.

    NOTE(review): the inputs are presumably token/phoneme ids and per-example
    sequence lengths. With all lengths equal to zero this only exercises
    shapes, not masking behaviour. Confirm against ``DurationModel.__call__``.
    """
    model = DurationModel()
    batch_size, seq_len = 2, 10
    token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
    seq_lengths = jnp.zeros((batch_size,), dtype=jnp.int32)
    output = model(token_ids, seq_lengths)
    assert output.shape == (batch_size, seq_len, 1)