hieugiaosu
committed on
Commit
·
9160a00
1
Parent(s):
42e1680
update model
Browse files- network/models/TF_gridnet.py +102 -0
- network/models/__init__.py +1 -0
- network/modules/convolution_module.py +20 -0
network/models/TF_gridnet.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from ..modules.tf_gridnet_modules import *
|
3 |
+
from ..modules.input_tranformation import STFTInput, RMSNormalizeInput
|
4 |
+
from ..modules.output_transformation import WaveGeneratorByISTFT, RMSDenormalizeOutput
|
5 |
+
|
6 |
+
class TF_Gridnet(nn.Module):
    """End-to-end waveform model built around a stack of TF-GridNet blocks.

    The forward pass chains, in order: RMS normalisation of the input,
    an STFT front end, a convolutional embedding, ``n_layers`` identical
    ``TFGridnetBlock`` modules, a deconvolution head, an ISTFT back end and
    finally RMS de-normalisation using the statistics captured at the input.

    Args:
        n_srcs: Forwarded to ``TFGridnetDeconv`` as ``n_srcs``.
        n_fft: FFT size; also used as ``win_length`` for both STFT and ISTFT.
        hop_length: Hop size shared by the STFT and ISTFT stages.
        window: Window name shared by the STFT and ISTFT stages.
        n_audio_channel: Forwarded to ``DimensionEmbedding`` as
            ``audio_channel``.
        n_layers: Number of ``TFGridnetBlock`` modules stacked sequentially.
        input_kernel_size_T: Second entry of the embedding kernel size.
        input_kernel_size_F: First entry of the embedding kernel size.
        output_kernel_size_T: ``kernel_size_T`` of the deconvolution head; its
            padding is ``output_kernel_size_T // 2``.
        output_kernel_size_F: ``kernel_size_F`` of the deconvolution head; its
            padding is ``output_kernel_size_F // 2``.
        lstm_hidden_units: Forwarded to every block as ``hidden_channels``.
        attn_n_head: Forwarded to every block as ``n_head``.
        qk_output_channel: Forwarded to every block unchanged.
        emb_dim: Embedding width shared by the embedding, blocks and head.
        emb_ks: Forwarded to every block as ``kernel_size``.
        emb_hs: Forwarded to every block as ``emb_hop_size``.
        activation: Activation name forwarded to every block.
        eps: Epsilon forwarded to the embedding and to every block.
    """

    def __init__(
        self,
        n_srcs=2,
        n_fft=128,
        hop_length=64,
        window="hann",
        n_audio_channel=1,
        n_layers=6,
        input_kernel_size_T=3,
        input_kernel_size_F=3,
        output_kernel_size_T=3,
        output_kernel_size_F=3,
        lstm_hidden_units=192,
        attn_n_head=4,
        qk_output_channel=4,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="PReLU",
        eps=1.0e-5,
    ):
        super().__init__()

        # Input scaling; the second value it returns is handed back to
        # ``output_denormalize`` at the end of ``forward``.
        self.input_normalize = RMSNormalizeInput((1, 2), keepdim=True)

        # Analysis and synthesis share identical STFT settings.
        self.stft = STFTInput(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )
        self.istft = WaveGeneratorByISTFT(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )

        self.output_denormalize = RMSDenormalizeOutput()

        self.dimension_embedding = DimensionEmbedding(
            audio_channel=n_audio_channel,
            emb_dim=emb_dim,
            kernel_size=(input_kernel_size_F, input_kernel_size_T),
            eps=eps,
        )

        # ``n_layers`` independently parameterised blocks applied in sequence.
        self.tf_gridnet_block = nn.Sequential(
            *[
                TFGridnetBlock(
                    emb_dim=emb_dim,
                    kernel_size=emb_ks,
                    emb_hop_size=emb_hs,
                    hidden_channels=lstm_hidden_units,
                    n_head=attn_n_head,
                    qk_output_channel=qk_output_channel,
                    activation=activation,
                    eps=eps,
                )
                for _ in range(n_layers)
            ]
        )

        self.deconv = TFGridnetDeconv(
            emb_dim=emb_dim,
            n_srcs=n_srcs,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F // 2,
            padding_T=output_kernel_size_T // 2,
        )

    def forward(self, input):
        """Run the full pipeline on a batch of waveforms.

        Args:
            input: Waveform tensor whose last axis is time. A 2-D tensor gets
                a singleton axis inserted at position 1 before processing.

        Returns:
            Whatever ``output_denormalize`` produces from the ISTFT output and
            the normalisation statistics captured at the input.
        """
        # Remember the original number of samples so the ISTFT stage can be
        # told the target length.
        audio_length = input.shape[-1]

        x = input
        if x.dim() == 2:
            x = x.unsqueeze(1)

        x, std = self.input_normalize(x)
        x = self.stft(x)
        x = self.dimension_embedding(x)
        x = self.tf_gridnet_block(x)
        x = self.deconv(x)

        # Swap axes 1 and 2 ("B C N F T -> B N C F T"): per the original
        # author's note, the ISTFT stage expects the real/imaginary axis at
        # dim 1. A plain tensor transpose is used so this module does not
        # depend on ``rearrange`` leaking in through a wildcard import.
        # NOTE(review): assumes ``self.deconv`` returns a 5-D torch tensor.
        x = x.transpose(1, 2)

        x = self.istft(x, audio_length)
        x = self.output_denormalize(x, std)
        return x
|
network/models/__init__.py
CHANGED
@@ -1,2 +1,3 @@
|
|
|
|
1 |
from .TF_gridnet_with_condition import *
|
2 |
from .embedding_model import ResemblyzerVoiceEncoder
|
|
|
1 |
+
from .TF_gridnet import *
|
2 |
from .TF_gridnet_with_condition import *
|
3 |
from .embedding_model import ResemblyzerVoiceEncoder
|
network/modules/convolution_module.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
class SplitFeatureDeconv(nn.Module):
    """Transposed 2-D convolution whose output channels are unpacked into two axes.

    The layer maps ``emb_dim`` input channels to ``n_srcs * emb_dim`` output
    channels, then splits that channel axis so the result has five axes
    ordered as ``B C N F T`` with ``C == n_srcs``.

    Args:
        emb_dim: Number of input channels of the transposed convolution.
        n_srcs: Size of the ``C`` axis split out of the output channels.
        kernel_size_T: Kernel extent along the last (``T``) axis.
        kernel_size_F: Kernel extent along the second-to-last (``F``) axis.
        padding_T: Padding along the ``T`` axis.
        padding_F: Padding along the ``F`` axis.
    """

    def __init__(
        self,
        emb_dim=48,
        n_srcs=2,
        kernel_size_T=3,
        kernel_size_F=3,
        padding_T=1,
        padding_F=1,
    ) -> None:
        super().__init__()
        self.n_srcs = n_srcs

        # Kernel and padding are ordered (F, T) to match the "B _ F T" layout
        # used in ``forward``.
        kernel = (kernel_size_F, kernel_size_T)
        pad = (padding_F, padding_T)
        self.deconv = nn.ConvTranspose2d(
            in_channels=emb_dim,
            out_channels=n_srcs * emb_dim,
            kernel_size=kernel,
            padding=pad,
        )

    def forward(self, input):
        """Apply the transposed convolution and split its channel axis.

        Args:
            input: 4-axis tensor laid out as ``B channels F T``.

        Returns:
            5-axis tensor laid out as ``B C N F T``, where ``C`` has size
            ``n_srcs`` and ``N`` is the remaining factor of the channel axis.
        """
        expanded = self.deconv(input)
        # Within the packed channel axis, ``C`` is the fastest-varying factor.
        return rearrange(expanded, "B (N C) F T -> B C N F T", C=self.n_srcs)
|