groadabike
committed on
Upload tasnet.py
tasnet.py
ADDED
@@ -0,0 +1,532 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Created on 2018/12
# Author: Kaituo XU
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
# Here is the original license:
# The MIT License (MIT)
#
# Copyright (c) 2018 Kaituo XU
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

EPS = 1e-8

def overlap_and_add(signal, frame_step):
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = Greatest Common Divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes, device=signal.device).unfold(
        0, subframes_per_frame, subframe_step
    )
    frame = frame.long()  # signal may be on GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result

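# Illustrative note (not part of the original upload), assuming the 50% overlap used below:
# with 3 all-ones frames of length 4 and frame_step=2,
#     overlap_and_add(torch.ones(3, 4), 2)
# returns a signal of length 2 * (3 - 1) + 4 = 8, tensor([1., 1., 2., 2., 2., 2., 1., 1.]),
# i.e. overlapping regions are summed rather than cross-faded.
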
class ConvTasNetStereo(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        N=256,
        L=20,
        B=256,
        H=512,
        P=3,
        X=8,
        R=4,
        C=2,
        audio_channels=2,
        samplerate=44100,
        norm_type="gLN",
        causal=False,
        mask_nonlinear="relu",
    ):
        """
        Args:
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            audio_channels: Number of audio channels (e.g. 2 for stereo)
            samplerate: Sampling rate of the audio, stored with the model
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: non-linear function used to generate the mask
        """
        super(ConvTasNetStereo, self).__init__()
        # Hyper-parameters
        self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = (
            N,
            L,
            B,
            H,
            P,
            X,
            R,
            C,
        )
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        self.audio_channels = audio_channels
        self.samplerate = samplerate
        # Components
        self.encoder = Encoder(L, N, audio_channels)
        self.separator = TemporalConvNet(
            N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear
        )
        self.decoder = Decoder(N, L, audio_channels)
        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def valid_length(self, length):
        return length

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # T changed after conv1d in encoder, fix it here
        T_origin = mixture.size(-1)
        T_conv = est_source.size(-1)
        est_source = F.pad(est_source, (0, T_origin - T_conv))
        return est_source

    def serialize(self):
        """Serialize model and output dictionary.

        Returns:
            dict, serialized model with keys `model_args` and `state_dict`.
        """
        import pytorch_lightning as pl  # Not used in torch.hub

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
            asteroid_version="0.7.0",
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Arguments needed to re-instantiate the model."""
        return dict(
            N=self.N,
            L=self.L,
            B=self.B,
            H=self.H,
            P=self.P,
            X=self.X,
            R=self.R,
            C=self.C,
            audio_channels=self.audio_channels,
            samplerate=self.samplerate,
            norm_type=self.norm_type,
            causal=self.causal,
            mask_nonlinear=self.mask_nonlinear,
        )

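# Illustrative usage sketch (not part of the original upload); shapes follow the
# forward() docstring above, hyperparameters are the defaults:
#     model = ConvTasNetStereo()                          # C=2 sources, stereo I/O
#     mix = torch.randn(1, model.audio_channels, 44100)   # [M, audio_channels, T]
#     with torch.no_grad():
#         sources = model(mix)                            # [M, C, audio_channels, T]
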
class Encoder(nn.Module):
    """Estimation of the nonnegative mixture weight by a 1-D conv layer."""

    def __init__(self, L, N, audio_channels):
        super(Encoder, self).__init__()
        # Hyper-parameters
        self.L, self.N = L, N
        # Components
        # 50% overlap
        self.conv1d_U = nn.Conv1d(
            audio_channels, N, kernel_size=L, stride=L // 2, bias=False
        )

    def forward(self, mixture):
        """
        Args:
            mixture: [M, audio_channels, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
        """
        mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
        return mixture_w

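# Illustrative note (not part of the original upload): with the default L=20 and a
# one-second 44.1 kHz input (T=44100), the encoder emits K = (44100 - 20) / 10 + 1 = 4409
# frames, i.e. one frame every L/2 = 10 samples.
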
class Decoder(nn.Module):
    def __init__(self, N, L, audio_channels):
        super(Decoder, self).__init__()
        # Hyper-parameters
        self.N, self.L = N, L
        self.audio_channels = audio_channels
        # Components
        self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)

    def forward(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [M, N, K]
            est_mask: [M, C, N, K]
        Returns:
            est_source: [M, C, audio_channels, T]
        """
        # D = W * M
        source_w = torch.unsqueeze(mixture_w, 1) * est_mask  # [M, C, N, K]
        source_w = torch.transpose(source_w, 2, 3)  # [M, C, K, N]
        # S = DV
        est_source = self.basis_signals(source_w)  # [M, C, K, audio_channels * L]
        m, c, k, _ = est_source.size()
        est_source = (
            est_source.view(m, c, k, self.audio_channels, -1)
            .transpose(2, 3)
            .contiguous()
        )
        est_source = overlap_and_add(est_source, self.L // 2)  # M x C x audio_channels x T
        return est_source

class TemporalConvNet(nn.Module):
    def __init__(
        self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"
    ):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: non-linear function used to generate the mask
        """
        super(TemporalConvNet, self).__init__()
        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        # Components
        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        for r in range(R):
            blocks = []
            for x in range(X):
                dilation = 2**x
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
                blocks += [
                    TemporalBlock(
                        B,
                        H,
                        P,
                        stride=1,
                        padding=padding,
                        dilation=dilation,
                        norm_type=norm_type,
                        causal=causal,
                    )
                ]
            repeats += [nn.Sequential(*blocks)]
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]
        mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
        # Put together
        self.network = nn.Sequential(
            layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1
        )

    def forward(self, mixture_w):
        """
        Keep this API the same as TasNet.
        Args:
            mixture_w: [M, N, K], M is batch size
        Returns:
            est_mask: [M, C, N, K]
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
        score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask

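# Illustrative note (not part of the original upload): each TemporalBlock adds
# (P - 1) * 2**x frames of context, so one repeat spans (P - 1) * (2**X - 1) frames and the
# full stack roughly R * (P - 1) * (2**X - 1) + 1 frames. With the defaults (P=3, X=8, R=4)
# that is about 2041 encoder frames, i.e. roughly 0.46 s of audio at 44.1 kHz with L=20
# (10-sample hop).
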
class TemporalBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super(TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]
        dsconv = DepthwiseSeparableConv(
            out_channels,
            in_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            norm_type,
            causal,
        )
        # Put together
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        residual = x
        out = self.net(x)
        # TODO: when P = 3 this works fine, but when P = 2 it may need extra padding
        return out + residual  # the residual without F.relu seems to work better than with it
        # return F.relu(out + residual)

class DepthwiseSeparableConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super(DepthwiseSeparableConv, self).__init__()
        # Use `groups` option to implement depthwise convolution
        # [M, H, K] -> [M, H, K]
        depthwise_conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        if causal:
            chomp = Chomp1d(padding)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Put together
        if causal:
            self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)

class Chomp1d(nn.Module):
    """To ensure the output length is the same as the input."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K]
        """
        return x[:, :, : -self.chomp_size].contiguous()

def chose_norm(norm_type, channel_size):
    """The input of the normalization will be (M, C, K), where M is batch size,
    C is channel size and K is sequence length.
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    elif norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    elif norm_type == "id":
        return nn.Identity()
    else:  # norm_type == "BN"
        # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
        # along M and K, so this BN usage is correct.
        return nn.BatchNorm1d(channel_size)

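# Illustrative note (not part of the original upload): both layer norms below compute
#     y_norm = gamma * (y - mean) / sqrt(var + EPS) + beta
# with learnable per-channel gamma/beta; cLN takes the mean/variance over the channel
# dimension at each time step, while gLN pools them over both channels and time.
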
# TODO: Use nn.LayerNorm to impl cLN to speed up
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN)"""

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return cLN_y

class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN)"""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        # TODO: in torch >= 1.0, torch.mean() supports a list of dims
        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
        var = (
            (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        )
        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return gLN_y

if __name__ == "__main__":
    torch.manual_seed(123)
    M, N, L, T = 2, 3, 4, 12
    audio_channels = 2
    K = 2 * T // L - 1
    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
    mixture = torch.randint(3, (M, audio_channels, T)).float()
    # test Encoder
    encoder = Encoder(L, N, audio_channels)
    encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()).float()
    mixture_w = encoder(mixture)
    print("mixture", mixture)
    print("U", encoder.conv1d_U.weight)
    print("mixture_w", mixture_w)
    print("mixture_w size", mixture_w.size())

    # test TemporalConvNet
    separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
    est_mask = separator(mixture_w)
    print("est_mask", est_mask)

    # test Decoder
    decoder = Decoder(N, L, audio_channels)
    est_mask = torch.randint(2, (M, C, N, K)).float()  # [M, C, N, K]
    est_source = decoder(mixture_w, est_mask)
    print("est_source", est_source)

    # test Conv-TasNet
    conv_tasnet = ConvTasNetStereo(
        N, L, B, H, P, X, R, C, audio_channels=audio_channels, norm_type=norm_type
    )
    est_source = conv_tasnet(mixture)
    print("est_source", est_source)
    print("est_source size", est_source.size())
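# Illustrative sketch (not part of the original upload): since ConvTasNetStereo mixes in
# PyTorchModelHubMixin, a checkpoint pushed to the Hub can in principle be restored with
# from_pretrained; the repo id below is a placeholder, not a confirmed model name.
#
#     model = ConvTasNetStereo.from_pretrained("<username>/<repo-id>")
#     mix = torch.randn(1, model.audio_channels, model.samplerate)  # 1 s of stereo audio
#     with torch.no_grad():
#         sources = model(mix)  # [1, C, audio_channels, T]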