File size: 3,706 Bytes
12da6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import haiku as hk
import jax
import jax.numpy as jnp

# Negative slope passed to jax.nn.leaky_relu inside the residual blocks and
# before each upsampling layer of the generator defined in this file.
LRELU_SLOPE = 0.1


def get_padding(kernel_size, dilation=1):
    """Return symmetric padding for a (possibly dilated) 1-D kernel.

    The per-side amount is ``(kernel_size * dilation - dilation) / 2``
    truncated to an int.

    Args:
        kernel_size: Size of the convolution kernel.
        dilation: Dilation rate applied to the kernel.

    Returns:
        A one-element tuple ``((p, p),)`` with the same amount on both
        sides, in the form this file passes as ``padding=`` to ``hk.Conv1D``.
    """
    dilated_extent = kernel_size * dilation - dilation
    pad = int(dilated_extent / 2)
    return ((pad, pad),)


class ResBlock1(hk.Module):
    """Residual block built from three (dilated conv, plain conv) pairs.

    Each pair computes ``x + conv2(lrelu(conv1(lrelu(x))))`` where ``conv1``
    uses one of the first three entries of ``dilation`` as its rate and
    ``conv2`` uses rate 1. Both convolutions keep ``channels`` output
    channels and use stride 1 with padding from ``get_padding``.

    Args:
        h: Hyper-parameter object; stored on the instance but not read by
            this class.
        channels: Number of output channels of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Sequence whose first three entries are the rates of the
            first convolution of each pair.
        name: Haiku module name.
    """

    def __init__(
        self, h, channels, kernel_size=3, dilation=(1, 3, 5), name="resblock1"
    ):
        super().__init__(name=name)
        self.h = h

        # First convolution of each pair: dilated by dilation[idx].
        dilated_convs = []
        for idx in range(3):
            rate = dilation[idx]
            dilated_convs.append(
                hk.Conv1D(
                    output_channels=channels,
                    kernel_shape=kernel_size,
                    stride=1,
                    rate=rate,
                    padding=get_padding(kernel_size, rate),
                    name=f"convs1_{idx}",
                )
            )
        self.convs1 = dilated_convs

        # Second convolution of each pair: always rate 1.
        plain_convs = []
        for idx in range(3):
            plain_convs.append(
                hk.Conv1D(
                    output_channels=channels,
                    kernel_shape=kernel_size,
                    stride=1,
                    rate=1,
                    padding=get_padding(kernel_size, 1),
                    name=f"convs2_{idx}",
                )
            )
        self.convs2 = plain_convs

    def __call__(self, x):
        """Apply the three residual pairs in sequence and return the result.

        NOTE(review): the residual add requires the conv output to have the
        same shape as ``x`` — presumably ``x`` already has ``channels``
        features; confirm against the caller.
        """
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            skip = x
            branch = dilated_conv(jax.nn.leaky_relu(x, LRELU_SLOPE))
            branch = plain_conv(jax.nn.leaky_relu(branch, LRELU_SLOPE))
            x = branch + skip
        return x


class ResBlock2(hk.Module):
    """Residual block built from two dilated convolutions.

    Each step computes ``x + conv(lrelu(x))`` where the convolution uses one
    of the first two entries of ``dilation`` as its rate, stride 1, and
    padding from ``get_padding``.

    Args:
        h: Hyper-parameter object; stored on the instance but not read by
            this class.
        channels: Number of output channels of every convolution.
        kernel_size: Kernel size shared by both convolutions.
        dilation: Sequence whose first two entries are the dilation rates.
        name: Haiku module name.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), name="ResBlock2"):
        super().__init__(name=name)
        self.h = h

        # The convolutions are intentionally left unnamed, as in the
        # original; they are created in index order.
        dilated_convs = []
        for idx in range(2):
            rate = dilation[idx]
            dilated_convs.append(
                hk.Conv1D(
                    output_channels=channels,
                    kernel_shape=kernel_size,
                    stride=1,
                    rate=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )
        self.convs = dilated_convs

    def __call__(self, x):
        """Apply both residual steps in sequence and return the result.

        NOTE(review): the residual add requires the conv output to have the
        same shape as ``x`` — presumably ``x`` already has ``channels``
        features; confirm against the caller.
        """
        for conv in self.convs:
            skip = x
            branch = conv(jax.nn.leaky_relu(x, LRELU_SLOPE))
            x = branch + skip
        return x


class Generator(hk.Module):
    """Upsampling generator: pre-conv, transposed-conv stages, tanh output.

    Every upsampling stage applies a leaky ReLU, a transposed convolution,
    and then averages the outputs of ``num_kernels`` residual blocks that all
    receive the same upsampled input. A final leaky ReLU, a single-channel
    convolution and ``tanh`` produce the output.

    Args:
        h: Hyper-parameter object. This class reads ``resblock``,
            ``resblock_kernel_sizes``, ``resblock_dilation_sizes``,
            ``upsample_rates``, ``upsample_kernel_sizes`` and
            ``upsample_initial_channel`` from it.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # NOTE(review): conv_pre and conv_post are unnamed, so their
        # auto-generated Haiku names presumably depend on creation order —
        # keep conv_pre first and conv_post last.
        self.conv_pre = hk.Conv1D(
            output_channels=h.upsample_initial_channel,
            kernel_shape=7,
            stride=1,
            padding=((3, 3),),
        )

        # ResBlock1 is selected only by the exact string "1".
        if h.resblock == "1":
            resblock = ResBlock1
        else:
            resblock = ResBlock2

        # Stage i halves the channel count again: initial // 2 ** (i + 1).
        self.ups = [
            hk.Conv1DTranspose(
                output_channels=h.upsample_initial_channel // (2 ** (stage + 1)),
                kernel_shape=kernel,
                stride=rate,
                padding="SAME",
                name=f"ups_{stage}",
            )
            for stage, (rate, kernel) in enumerate(
                zip(h.upsample_rates, h.upsample_kernel_sizes)
            )
        ]

        # Flat list of residual blocks, grouped by stage; the name suffix is
        # the block's position in this flat list.
        self.resblocks = []
        for stage, _ in enumerate(self.ups):
            width = h.upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilations in zip(
                h.resblock_kernel_sizes, h.resblock_dilation_sizes
            ):
                flat_index = len(self.resblocks)
                self.resblocks.append(
                    resblock(
                        h, width, kernel, dilations, name=f"res_block1_{flat_index}"
                    )
                )

        self.conv_post = hk.Conv1D(
            output_channels=1,
            kernel_shape=7,
            stride=1,
            padding=((3, 3),),
        )

    def __call__(self, x):
        """Run the generator on ``x`` and return the tanh-squashed output.

        NOTE(review): the layout of ``x`` is not visible here — presumably
        (batch, time, features) as consumed by ``hk.Conv1D``; confirm
        against the caller.
        """
        x = self.conv_pre(x)
        for stage in range(self.num_upsamples):
            x = self.ups[stage](jax.nn.leaky_relu(x, LRELU_SLOPE))

            # Average the residual blocks belonging to this stage.
            first = stage * self.num_kernels
            total = None
            for offset in range(self.num_kernels):
                block_out = self.resblocks[first + offset](x)
                total = block_out if total is None else total + block_out
            x = total / self.num_kernels

        # No slope argument here: the library default is used instead of
        # LRELU_SLOPE (the original comment: "default pytorch value").
        x = jax.nn.leaky_relu(x)
        return jnp.tanh(self.conv_post(x))