File size: 3,706 Bytes
12da6cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import haiku as hk
import jax
import jax.numpy as jnp
# Negative slope passed to jax.nn.leaky_relu inside the residual blocks and in
# the generator's upsampling loop.
LRELU_SLOPE = 0.1
def get_padding(kernel_size, dilation=1):
    """Return symmetric 1-D padding for a (possibly dilated) kernel.

    The per-side amount is ``(kernel_size * dilation - dilation) / 2``
    truncated to an int.

    Args:
        kernel_size: Width of the convolution kernel.
        dilation: Dilation rate of the kernel (defaults to 1).

    Returns:
        A one-element tuple ``((pad, pad),)`` holding the same padding for
        both sides of the single spatial axis.
    """
    dilated_extent = kernel_size * dilation - dilation
    # int() truncates toward zero, so an odd extent rounds the padding down.
    per_side = int(dilated_extent / 2)
    return ((per_side, per_side),)
class ResBlock1(hk.Module):
    """Residual block built from three (dilated conv, plain conv) pairs.

    For each pair the input goes through leaky ReLU, the dilated convolution,
    leaky ReLU again and the rate-1 convolution; the result is added back to
    the pair's input.

    Args:
        h: Hyper-parameter object; stored on the instance but not read by
            this class.
        channels: Number of output channels of every convolution.
        kernel_size: Kernel width shared by all convolutions.
        dilation: Sequence whose first three entries are the dilation rates
            of the first convolution in each pair.
        name: Haiku module name.
    """

    def __init__(
        self, h, channels, kernel_size=3, dilation=(1, 3, 5), name="resblock1"
    ):
        super().__init__(name=name)
        self.h = h

        # First convolution of each pair: dilated by dilation[idx].
        self.convs1 = []
        for idx in range(3):
            rate = dilation[idx]
            self.convs1.append(
                hk.Conv1D(
                    output_channels=channels,
                    kernel_shape=kernel_size,
                    stride=1,
                    rate=rate,
                    padding=get_padding(kernel_size, rate),
                    name=f"convs1_{idx}",
                )
            )

        # Second convolution of each pair: always undilated (rate 1).
        self.convs2 = []
        for idx in range(3):
            self.convs2.append(
                hk.Conv1D(
                    output_channels=channels,
                    kernel_shape=kernel_size,
                    stride=1,
                    rate=1,
                    padding=get_padding(kernel_size, 1),
                    name=f"convs2_{idx}",
                )
            )

    def __call__(self, x):
        """Run ``x`` through the three residual pairs and return the result."""
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            skip = x
            y = dilated_conv(jax.nn.leaky_relu(skip, LRELU_SLOPE))
            y = plain_conv(jax.nn.leaky_relu(y, LRELU_SLOPE))
            x = y + skip
        return x
class ResBlock2(hk.Module):
    """Residual block built from two dilated convolutions.

    Each convolution is preceded by a leaky ReLU and its output is added back
    to its own input.

    Args:
        h: Hyper-parameter object; stored on the instance but not read by
            this class.
        channels: Number of output channels of every convolution.
        kernel_size: Kernel width shared by both convolutions.
        dilation: Sequence whose first two entries are the dilation rates.
        name: Haiku module name.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), name="ResBlock2"):
        super().__init__(name=name)
        self.h = h

        # NOTE(review): these convolutions carry no explicit name, so their
        # names are whatever Haiku assigns automatically -- presumably tied to
        # creation order; keep the order and count unchanged.
        self.convs = []
        for idx in range(2):
            rate = dilation[idx]
            self.convs.append(
                hk.Conv1D(
                    output_channels=channels,
                    kernel_shape=kernel_size,
                    stride=1,
                    rate=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

    def __call__(self, x):
        """Run ``x`` through both residual convolutions and return the result."""
        for conv in self.convs:
            skip = x
            y = conv(jax.nn.leaky_relu(skip, LRELU_SLOPE))
            x = y + skip
        return x
class Generator(hk.Module):
    """Upsampling generator: pre-conv, transposed-conv stages, post-conv, tanh.

    Every upsampling stage applies a leaky ReLU and a transposed convolution,
    then averages the outputs of ``num_kernels`` residual blocks that all
    receive the same stage input.

    Args:
        h: Hyper-parameter object. This class reads ``resblock``,
            ``resblock_kernel_sizes``, ``resblock_dilation_sizes``,
            ``upsample_rates``, ``upsample_kernel_sizes`` and
            ``upsample_initial_channel`` from it.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # Unnamed module: must stay the first unnamed Conv1D created here.
        self.conv_pre = hk.Conv1D(
            h.upsample_initial_channel, 7, 1, padding=((3, 3),)
        )

        if h.resblock == "1":
            resblock = ResBlock1
        else:
            resblock = ResBlock2

        # One transposed convolution per (rate, kernel size) pair; the channel
        # count halves at every stage.
        self.ups = [
            hk.Conv1DTranspose(
                h.upsample_initial_channel // (2 ** (stage + 1)),
                kernel_shape=k,
                stride=u,
                padding="SAME",
                name=f"ups_{stage}",
            )
            for stage, (u, k) in enumerate(
                zip(h.upsample_rates, h.upsample_kernel_sizes)
            )
        ]

        # Residual blocks are stored flat, stage after stage, and are named by
        # their position in that flat list.
        self.resblocks = []
        for stage in range(len(self.ups)):
            stage_channels = h.upsample_initial_channel // (2 ** (stage + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                block_name = f"res_block1_{len(self.resblocks)}"
                self.resblocks.append(
                    resblock(h, stage_channels, k, d, name=block_name)
                )

        self.conv_post = hk.Conv1D(1, 7, 1, padding=((3, 3),))

    def __call__(self, x):
        """Map the input features ``x`` to the tanh-squashed generator output."""
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = self.ups[i](jax.nn.leaky_relu(x, LRELU_SLOPE))

            # Accumulate the outputs of this stage's residual blocks, which
            # occupy a contiguous run of the flat ``resblocks`` list.
            first_block = i * self.num_kernels
            xs = None
            for j in range(self.num_kernels):
                branch = self.resblocks[first_block + j](x)
                xs = branch if xs is None else xs + branch
            x = xs / self.num_kernels

        # No explicit slope here, so leaky_relu falls back to its own default;
        # the original note marks this as the default PyTorch value.
        x = self.conv_post(jax.nn.leaky_relu(x))
        return jnp.tanh(x)
|