hieugiaosu committed on
Commit
9160a00
·
1 Parent(s): 42e1680

update model

Browse files
network/models/TF_gridnet.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from ..modules.tf_gridnet_modules import *
3
+ from ..modules.input_tranformation import STFTInput, RMSNormalizeInput
4
+ from ..modules.output_transformation import WaveGeneratorByISTFT, RMSDenormalizeOutput
5
+
6
class TF_Gridnet(nn.Module):
    """Waveform-to-waveform model assembled from the TF-GridNet modules.

    The forward pass chains, in order: RMS normalization, STFT, dimension
    embedding, ``n_layers`` stacked ``TFGridnetBlock`` modules, a
    deconvolution head, inverse STFT and RMS de-normalization.

    Args:
        n_srcs: Forwarded to ``TFGridnetDeconv`` as ``n_srcs``.
        n_fft: Used as both ``n_fft`` and ``win_length`` for the STFT and
            the inverse STFT.
        hop_length: Hop length for the STFT and the inverse STFT.
        window: Window name forwarded to the STFT and the inverse STFT.
        n_audio_channel: Forwarded to ``DimensionEmbedding`` as
            ``audio_channel``.
        n_layers: Number of ``TFGridnetBlock`` modules stacked sequentially.
        input_kernel_size_T: Time-axis kernel size of the embedding layer.
        input_kernel_size_F: Frequency-axis kernel size of the embedding
            layer.
        output_kernel_size_T: Time-axis kernel size of the deconvolution
            head; its padding is ``output_kernel_size_T // 2``.
        output_kernel_size_F: Frequency-axis kernel size of the
            deconvolution head; its padding is ``output_kernel_size_F // 2``.
        lstm_hidden_units: Forwarded to each block as ``hidden_channels``.
        attn_n_head: Forwarded to each block as ``n_head``.
        qk_output_channel: Forwarded to each block as ``qk_output_channel``.
        emb_dim: Embedding dimension shared by the embedding layer, the
            blocks and the deconvolution head.
        emb_ks: Forwarded to each block as ``kernel_size``.
        emb_hs: Forwarded to each block as ``emb_hop_size``.
        activation: Activation name forwarded to each block.
        eps: Epsilon forwarded to the embedding layer and to each block.
    """

    def __init__(
        self,
        n_srcs=2,
        n_fft=128,
        hop_length=64,
        window="hann",
        n_audio_channel=1,
        n_layers=6,
        input_kernel_size_T=3,
        input_kernel_size_F=3,
        output_kernel_size_T=3,
        output_kernel_size_F=3,
        lstm_hidden_units=192,
        attn_n_head=4,
        qk_output_channel=4,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="PReLU",
        eps=1.0e-5,
    ):
        super().__init__()

        # Waveform <-> spectrogram front end and back end. The analysis and
        # synthesis transforms deliberately share the same STFT settings.
        self.input_normalize = RMSNormalizeInput((1, 2), keepdim=True)
        self.stft = STFTInput(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )
        self.istft = WaveGeneratorByISTFT(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )
        self.output_denormalize = RMSDenormalizeOutput()

        self.dimension_embedding = DimensionEmbedding(
            audio_channel=n_audio_channel,
            emb_dim=emb_dim,
            kernel_size=(input_kernel_size_F, input_kernel_size_T),
            eps=eps,
        )

        # n_layers independently parameterized blocks applied back to back.
        self.tf_gridnet_block = nn.Sequential(
            *(
                TFGridnetBlock(
                    emb_dim=emb_dim,
                    kernel_size=emb_ks,
                    emb_hop_size=emb_hs,
                    hidden_channels=lstm_hidden_units,
                    n_head=attn_n_head,
                    qk_output_channel=qk_output_channel,
                    activation=activation,
                    eps=eps,
                )
                for _ in range(n_layers)
            )
        )

        self.deconv = TFGridnetDeconv(
            emb_dim=emb_dim,
            n_srcs=n_srcs,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F // 2,
            padding_T=output_kernel_size_T // 2,
        )

    def forward(self, input):
        """Run the full pipeline on a batch of waveforms.

        Args:
            input: Waveform tensor whose last axis is the sample axis.
                A 2-D tensor gets a singleton axis inserted at position 1.
                NOTE(review): presumably (batch, samples) or
                (batch, channels, samples) -- confirm against callers.

        Returns:
            The output of ``RMSDenormalizeOutput`` applied to the inverse
            STFT result, which is generated for the original sample count.
        """
        # Remember the sample count so the inverse STFT can restore it.
        audio_length = input.shape[-1]

        x = input
        if x.dim() == 2:
            x = x.unsqueeze(1)

        x, std = self.input_normalize(x)
        x = self.stft(x)
        x = self.dimension_embedding(x)
        x = self.tf_gridnet_block(x)
        x = self.deconv(x)

        # Swap axes 1 and 2 ("B C N F T -> B N C F T") because in istft,
        # dim 1 is for the real and imaginary parts. A plain permute is used
        # instead of einops.rearrange so this module does not depend on the
        # name leaking through the wildcard import above; permute still
        # raises if the tensor is not 5-D.
        x = x.permute(0, 2, 1, 3, 4)

        x = self.istft(x, audio_length)
        x = self.output_denormalize(x, std)
        return x
network/models/__init__.py CHANGED
@@ -1,2 +1,3 @@
 
1
  from .TF_gridnet_with_condition import *
2
  from .embedding_model import ResemblyzerVoiceEncoder
 
1
+ from .TF_gridnet import *
2
  from .TF_gridnet_with_condition import *
3
  from .embedding_model import ResemblyzerVoiceEncoder
network/modules/convolution_module.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from einops import rearrange
3
+
4
class SplitFeatureDeconv(nn.Module):
    """Transposed 2-D convolution whose output channels are split per source.

    The layer maps ``emb_dim`` input channels to ``n_srcs * emb_dim`` output
    channels and then unfolds the channel axis into two axes.

    Args:
        emb_dim: Number of input channels; also the per-source channel count
            after the split.
        n_srcs: Number of groups the output channels are split into.
        kernel_size_T: Kernel size along the last (T) axis.
        kernel_size_F: Kernel size along the second-to-last (F) axis.
        padding_T: Padding along the T axis.
        padding_F: Padding along the F axis.
    """

    def __init__(
        self,
        emb_dim=48,
        n_srcs=2,
        kernel_size_T=3,
        kernel_size_F=3,
        padding_T=1,
        padding_F=1,
    ) -> None:
        super().__init__()
        self.n_srcs = n_srcs

        # Both tuples are ordered (F, T) to match the "B C F T" tensor layout.
        kernel = (kernel_size_F, kernel_size_T)
        pad = (padding_F, padding_T)
        self.deconv = nn.ConvTranspose2d(
            emb_dim,
            n_srcs * emb_dim,
            kernel,
            padding=pad,
        )

    def forward(self, input):
        """Apply the transposed convolution and split the channel axis.

        Args:
            input: 4-D tensor laid out as (B, emb_dim, F, T).

        Returns:
            5-D tensor of shape (B, n_srcs, emb_dim, F', T'). Note that in
            the einops pattern the label ``C`` is bound to ``n_srcs``, so the
            source index is the fastest-varying part of the channel axis.
        """
        return rearrange(
            self.deconv(input),
            "B (N C) F T -> B C N F T",
            C=self.n_srcs,
        )