# tts/vietTTS/hifigan/model.py
# Uploaded by tobiccino, revision 12da6cc (3.71 kB).
import haiku as hk
import jax
import jax.numpy as jnp
# Negative slope shared by the leaky-ReLU activations inside the residual
# blocks and before each upsampling layer. The final activation ahead of the
# generator's output convolution does not use this constant.
LRELU_SLOPE = 1e-1
def get_padding(kernel_size, dilation=1):
    """Return symmetric padding for a dilated 1-D convolution.

    The per-side amount is half of ``(kernel_size - 1) * dilation`` (the extra
    reach of the dilated kernel), truncated to an int. For odd kernel sizes
    and stride 1 this keeps the sequence length unchanged.

    The value is wrapped as a one-element tuple of ``(before, after)`` pairs,
    i.e. one pair per spatial dimension, as passed to ``hk.Conv1D(padding=...)``.
    """
    reach = kernel_size * dilation - dilation
    side = int(reach / 2)
    return ((side, side),)
class ResBlock1(hk.Module):
    """Residual block of type "1": three convolution pairs, each wrapped in a skip.

    Stage ``i`` computes ``x + convs2_i(lrelu(convs1_i(lrelu(x))))``, where
    ``convs1_i`` is dilated by ``dilation[i]`` and ``convs2_i`` is undilated.
    Every convolution has stride 1, ``channels`` outputs and symmetric padding
    from ``get_padding``. All leaky-ReLUs use ``LRELU_SLOPE``.

    Args:
        h: hyper-parameter object; kept as ``self.h`` but not read here.
        channels: output channels of every convolution. The residual additions
            need the convolution output to be shape-compatible with the input,
            so presumably this matches the input channel count.
        kernel_size: kernel width shared by all six convolutions.
        dilation: dilation rates for ``convs1_0..convs1_2``; only the first
            three entries are read.
        name: Haiku module name.
    """

    def __init__(
        self, h, channels, kernel_size=3, dilation=(1, 3, 5), name="resblock1"
    ):
        super().__init__(name=name)
        self.h = h
        # Exactly three stages; indexing (rather than slicing) still raises
        # IndexError when fewer than three rates are supplied.
        rates = [dilation[idx] for idx in range(3)]
        # NOTE: submodule names determine Haiku parameter paths; keep them.
        self.convs1 = [
            hk.Conv1D(
                channels,
                kernel_size,
                1,
                rate=rate,
                padding=get_padding(kernel_size, rate),
                name=f"convs1_{idx}",
            )
            for idx, rate in enumerate(rates)
        ]
        undilated_padding = get_padding(kernel_size, 1)
        self.convs2 = [
            hk.Conv1D(
                channels,
                kernel_size,
                1,
                rate=1,
                padding=undilated_padding,
                name=f"convs2_{idx}",
            )
            for idx in range(3)
        ]

    def __call__(self, x):
        """Apply the three residual stages in sequence and return the result."""
        out = x
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(jax.nn.leaky_relu(out, LRELU_SLOPE))
            branch = plain_conv(jax.nn.leaky_relu(branch, LRELU_SLOPE))
            out = branch + out
        return out
class ResBlock2(hk.Module):
    """Residual block of type "2": two dilated convolutions, each wrapped in a skip.

    Stage ``i`` computes ``x + conv_i(lrelu(x))`` with ``conv_i`` dilated by
    ``dilation[i]``, stride 1, ``channels`` outputs and symmetric padding from
    ``get_padding``. Unlike ``ResBlock1`` there is no second convolution per
    stage.

    Args:
        h: hyper-parameter object; kept as ``self.h`` but not read here.
        channels: output channels of each convolution (presumably equal to the
            input channel count, since the output is added back to the input).
        kernel_size: kernel width of both convolutions.
        dilation: dilation rates; only the first two entries are read.
        name: Haiku module name (the default is capitalised, unlike
            ``ResBlock1``'s).
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), name="ResBlock2"):
        super().__init__(name=name)
        self.h = h
        # NOTE: these convolutions get no explicit name, so Haiku generates
        # their names in creation order; keep the order stable so parameter
        # paths do not change.
        self.convs = []
        for idx in range(2):
            rate = dilation[idx]
            self.convs.append(
                hk.Conv1D(
                    channels,
                    kernel_size,
                    1,
                    rate=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

    def __call__(self, x):
        """Apply both residual stages in sequence and return the result."""
        out = x
        for conv in self.convs:
            out = conv(jax.nn.leaky_relu(out, LRELU_SLOPE)) + out
        return out
class Generator(hk.Module):
    """HiFi-GAN style generator mapping a conditioning sequence to one output channel.

    Pipeline (all hyper-parameters are read from ``h``):

    1. ``conv_pre``: 7-tap Conv1D to ``h.upsample_initial_channel`` channels.
    2. For each ``(u, k)`` in ``zip(h.upsample_rates, h.upsample_kernel_sizes)``
       (stage ``i``): leaky-ReLU, then a stride-``u`` Conv1DTranspose with
       ``"SAME"`` padding to ``upsample_initial_channel // 2 ** (i + 1)``
       channels. This is followed by the mean of ``len(h.resblock_kernel_sizes)``
       residual blocks, all applied to the same input.
    3. leaky-ReLU (slope 0.01), ``conv_post`` (7-tap Conv1D to one channel)
       and ``tanh``.

    Args:
        h: hyper-parameter object. Reads ``resblock`` ("1" selects
            ``ResBlock1``, anything else ``ResBlock2``),
            ``upsample_initial_channel``, ``upsample_rates``,
            ``upsample_kernel_sizes``, ``resblock_kernel_sizes`` and
            ``resblock_dilation_sizes``.

    Raises:
        ValueError: if ``upsample_kernel_sizes`` is shorter than
            ``upsample_rates``, if ``resblock_dilation_sizes`` is shorter than
            ``resblock_kernel_sizes``, or if there are upsampling stages but
            no residual-block kernels. Such configurations would otherwise
            fail only at call time, or index residual blocks belonging to
            another stage.

    NOTE: submodule names (``ups_i``, ``res_block1_i`` and the auto-generated
    Conv1D names) and their creation order determine Haiku parameter paths.
    Presumably existing parameters depend on them, so keep them stable.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # The zips below truncate to the shorter list, while __call__ indexes
        # by num_upsamples / num_kernels. Reject the configurations where that
        # mismatch would break; longer lists keep working as before (their
        # extra entries are ignored).
        if len(h.upsample_kernel_sizes) < self.num_upsamples:
            raise ValueError(
                f"upsample_kernel_sizes has {len(h.upsample_kernel_sizes)} "
                f"entries but upsample_rates has {self.num_upsamples}"
            )
        if len(h.resblock_dilation_sizes) < self.num_kernels:
            raise ValueError(
                f"resblock_dilation_sizes has {len(h.resblock_dilation_sizes)} "
                f"entries but resblock_kernel_sizes has {self.num_kernels}"
            )
        if self.num_upsamples and not self.num_kernels:
            raise ValueError(
                "resblock_kernel_sizes must not be empty when upsample_rates "
                "is non-empty"
            )

        self.conv_pre = hk.Conv1D(h.upsample_initial_channel, 7, 1, padding=((3, 3),))
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        self.ups = []
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                hk.Conv1DTranspose(
                    h.upsample_initial_channel // (2 ** (i + 1)),
                    kernel_shape=k,
                    stride=u,
                    padding="SAME",
                    name=f"ups_{i}",
                )
            )

        # Flat list laid out stage-major: stage i owns entries
        # [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = []
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                # The "res_block1_" prefix is used for both block types; it is
                # part of the parameter path, so it is kept as is.
                self.resblocks.append(
                    resblock(h, ch, k, d, name=f"res_block1_{len(self.resblocks)}")
                )
        self.conv_post = hk.Conv1D(1, 7, 1, padding=((3, 3),))

    def __call__(self, x):
        """Run the generator.

        Args:
            x: conditioning features; presumably ``(batch, time, channels)``
                in Haiku's default channels-last Conv1D layout. TODO: confirm
                with the caller.

        Returns:
            The single-channel output of ``conv_post``, squashed into
            ``(-1, 1)`` by ``tanh``.
        """
        x = self.conv_pre(x)
        for i, upsample in enumerate(self.ups):
            x = jax.nn.leaky_relu(x, LRELU_SLOPE)
            x = upsample(x)
            first = i * self.num_kernels
            stage = self.resblocks[first : first + self.num_kernels]
            # Every residual block of this stage sees the same input; their
            # outputs are averaged.
            xs = stage[0](x)
            for block in stage[1:]:
                xs = xs + block(x)
            x = xs / self.num_kernels
        # 0.01 is both JAX's default and PyTorch's F.leaky_relu default. It is
        # written out so it is not mistaken for LRELU_SLOPE.
        x = jax.nn.leaky_relu(x, negative_slope=0.01)
        x = self.conv_post(x)
        return jnp.tanh(x)