hieugiaosu committed on
Commit
7596274
·
1 Parent(s): e9096c9

Add application file

Files changed (41)
  1. app.py +102 -0
  2. checkpoints/concat_emb.pth +3 -0
  3. network/__init__.py +1 -0
  4. network/layers/ISTFT_layer.py +69 -0
  5. network/layers/STFT_Layer.py +67 -0
  6. network/layers/__init__.py +3 -0
  7. network/layers/__pycache__/ISTFT_layer.cpython-311.pyc +0 -0
  8. network/layers/__pycache__/STFTLayer.cpython-311.pyc +0 -0
  9. network/layers/__pycache__/STFT_Layer.cpython-311.pyc +0 -0
  10. network/layers/__pycache__/__init__.cpython-311.pyc +0 -0
  11. network/layers/__pycache__/film_layer.cpython-311.pyc +0 -0
  12. network/layers/film_layer.py +21 -0
  13. network/models/TF_gridnet_with_condition.py +638 -0
  14. network/models/__init__.py +2 -0
  15. network/models/__pycache__/TF_gridnet.cpython-311.pyc +0 -0
  16. network/models/__pycache__/TF_gridnet_with_condition.cpython-311.pyc +0 -0
  17. network/models/__pycache__/TF_gridnet_with_condition_new.cpython-311.pyc +0 -0
  18. network/models/__pycache__/__init__.cpython-311.pyc +0 -0
  19. network/models/__pycache__/embedding_model.cpython-311.pyc +0 -0
  20. network/models/__pycache__/new_style.cpython-311.pyc +0 -0
  21. network/models/__pycache__/tfGrideNetLOTH.cpython-311.pyc +0 -0
  22. network/models/embedding_model.py +14 -0
  23. network/modules/__init__.py +0 -0
  24. network/modules/attention.py +197 -0
  25. network/modules/gate_module.py +16 -0
  26. network/modules/input_tranformation.py +92 -0
  27. network/modules/output_transformation.py +47 -0
  28. network/modules/sequence_embed.py +138 -0
  29. network/modules/split_modules.py +121 -0
  30. network/modules/tf_gridnet_modules/__init__.py +3 -0
  31. network/modules/tf_gridnet_modules/__pycache__/__init__.cpython-311.pyc +0 -0
  32. network/modules/tf_gridnet_modules/__pycache__/deconv.cpython-311.pyc +0 -0
  33. network/modules/tf_gridnet_modules/__pycache__/dimension_embedding.cpython-311.pyc +0 -0
  34. network/modules/tf_gridnet_modules/__pycache__/tf_gridnet_block.cpython-311.pyc +0 -0
  35. network/modules/tf_gridnet_modules/deconv.py +20 -0
  36. network/modules/tf_gridnet_modules/dimension_embedding.py +15 -0
  37. network/modules/tf_gridnet_modules/tf_gridnet_block.py +255 -0
  38. network/utils/__init__.py +2 -0
  39. network/utils/enum_declare.py +6 -0
  40. network/utils/error_message.py +7 -0
  41. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,102 @@
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ from torch.cuda.amp import autocast
5
+
6
+ from network.models import FilterBandTFGridnet, ResemblyzerVoiceEncoder
7
+
8
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
+ device = 'cpu'
10
+ model = FilterBandTFGridnet(n_layers=5,conditional_dim=256*2)
11
+ emb = ResemblyzerVoiceEncoder(device=device)
12
+ mixed_voice_tool = None
13
+
14
+ def load_voice(voice_path):
15
+ voice, rate = torchaudio.load(voice_path)
16
+
17
+ if rate != 16000:
18
+ voice = torchaudio.functional.resample(voice, rate, 16000)
19
+ rate = 16000
20
+ voice = voice.float()
21
+ return voice, rate
22
+
23
+ def mix(voice1_path, voice2_path, snr=0):
24
+ global mixed_voice_tool
25
+ voice1, _ = load_voice(voice1_path)
26
+ voice2, _ = load_voice(voice2_path)
27
+ mix = torchaudio.functional.add_noise(voice1, voice2, torch.tensor([float(snr)])).float()
28
+ mixed_voice_tool = mix
29
+ return gr.Audio(tuple((16000,mix[0].numpy())),type='numpy')
30
+
31
+ # def separate_from_file(mixed_voice, ref_voice):
32
+
33
+ def separate(mixed_voice_path, clean_voice_path, drop_down):
34
+ if drop_down == 'From mixing tool':
35
+ mixed_voice = mixed_voice_tool
36
+ else:
37
+ mixed_voice,rate = load_voice(mixed_voice_path)
38
+ clean_voice,rate = load_voice(clean_voice_path)
39
+ if clean_voice.shape[-1] < 16000*4:
40
+ n = 16000*4 // clean_voice.shape[-1] + 1
41
+ clean_voice = torch.cat([clean_voice]*n, dim=-1)
42
+ clean_voice = clean_voice[:,:16000*4]
43
+ if not model:
44
+ return None
45
+ model.to(device)
46
+ model.eval()
47
+ e = emb(clean_voice)
48
+ e_mix = emb(mixed_voice)
49
+ e = torch.cat([e,e_mix],dim=1)
50
+ mixed_voice = torchaudio.functional.resample(mixed_voice, rate, 8000)
51
+ with autocast():
52
+ with torch.no_grad():
53
+ yHat = model(
54
+ mixed_voice,
55
+ e,
56
+ )
57
+ yHat = torchaudio.functional.resample(yHat, 8000, 16000).numpy().astype('float32')
58
+ audio = gr.Audio(tuple((16000,yHat[0])),type='numpy')
59
+ return audio
60
+
61
+ def load_checkpoint(filepath):
62
+ checkpoint = torch.load(
63
+ filepath,
64
+ weights_only=True,
65
+ map_location=device,
66
+ )
67
+ model.load_state_dict(checkpoint)
68
+
69
+ with gr.Blocks() as demo:
70
+ load_checkpoint('checkpoints/concat_emb.pth')
71
+ with gr.Row():
72
+ snr = gr.Slider(label='SNR', minimum=-10, maximum=10, step=1, value=0)
73
+ with gr.Row():
74
+ with gr.Column(scale=1,min_width=200):
75
+ voice1 = gr.Audio(label='speaker 1', type='filepath')
76
+ with gr.Column(scale=1,min_width=200):
77
+ voice2 = gr.Audio(label='speaker 2', type='filepath')
78
+ with gr.Column(scale=1,min_width=200):
79
+ with gr.Row():
80
+ mixed_voice = gr.Audio(label='Mixed voice')
81
+ with gr.Row():
82
+ btn = gr.Button("Mix voices", size='sm')
83
+ btn.click(mix, inputs=[voice1, voice2, snr], outputs=mixed_voice)
84
+ with gr.Row():
85
+ choose_mix_source = gr.Label('Extract target speaker voice from mixed voice')
86
+ with gr.Row():
87
+ drop_down = gr.Dropdown(['From mixing tool', 'Upload'], label='Choose mixed voice source')
88
+ with gr.Row():
89
+ with gr.Column(scale=1,min_width=200):
90
+ with gr.Row():
91
+ mixed_voice_path = gr.Audio(label='Mixed voice', type='filepath')
92
+ with gr.Column(scale=1,min_width=200):
93
+ with gr.Row():
94
+ ref_voice_path = gr.Audio(label='reference voice', type='filepath')
95
+ with gr.Column(scale=1,min_width=200):
96
+ with gr.Row():
97
+ sep_voice = gr.Audio(label="Separated voice")
98
+ with gr.Row():
99
+ btn = gr.Button("Separate voices", size='sm')
100
+ btn.click(separate, inputs=[mixed_voice_path, ref_voice_path, drop_down], outputs=sep_voice)
101
+ demo.launch()
102
+
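The app above drives everything through Gradio callbacks; the following is a minimal standalone sketch (not part of this commit) of the same pipeline, assuming the checkpoint path, the 16 kHz input / 8 kHz model-rate convention, and the concatenated reference+mixture embedding used in app.py. The helper name extract_target is hypothetical.

import torch
import torchaudio
from network.models import FilterBandTFGridnet, ResemblyzerVoiceEncoder

device = 'cpu'
model = FilterBandTFGridnet(n_layers=5, conditional_dim=256 * 2)
model.load_state_dict(torch.load('checkpoints/concat_emb.pth', map_location=device, weights_only=True))
model.eval()
encoder = ResemblyzerVoiceEncoder(device=device)

def extract_target(mixture_path, reference_path):
    # Hypothetical helper mirroring app.py: load both signals at 16 kHz, embed them,
    # run the separator at 8 kHz, and return the estimate resampled back to 16 kHz.
    mix, sr = torchaudio.load(mixture_path)
    ref, ref_sr = torchaudio.load(reference_path)
    mix = torchaudio.functional.resample(mix, sr, 16000)
    ref = torchaudio.functional.resample(ref, ref_sr, 16000)
    condition = torch.cat([encoder(ref), encoder(mix)], dim=1)  # (channels, 512)
    with torch.no_grad():
        estimate = model(torchaudio.functional.resample(mix, 16000, 8000), condition)
    return torchaudio.functional.resample(estimate, 8000, 16000)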
checkpoints/concat_emb.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c0d3e074f574a701ef9231d787611c47e86556b7ef7d63e1e0f6e4a4a6caa73
3
+ size 39815976
network/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __package__= "network"
network/layers/ISTFT_layer.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional
4
+ from ..utils import ErrorMessageUtil
5
+ from einops import rearrange
6
+
7
+ class InverseSTFTLayer(nn.Module):
8
+ def __init__(
9
+ self,
10
+ n_fft:int = 128,
11
+ win_length: Optional[int] = None,
12
+ hop_length:int = 64,
13
+ window: str = "hann",
14
+ center: bool = True,
15
+ normalized: bool = False,
16
+ onesided: bool = True,
17
+ ):
18
+ super().__init__()
19
+ self.n_fft = n_fft
20
+ self.win_length = win_length if win_length else n_fft
21
+ self.hop_length = hop_length
22
+ self.center = center
23
+ self.normalized = normalized
24
+ self.onesided = onesided
25
+ self.window = getattr(torch,f"{window}_window")
26
+ def forward(self,input,audio_length:int):
27
+ """STFT forward function.
28
+ Args:
29
+ input: (Batch, Freq, Frames) or (Batch, Channels, Freq, Frames)
30
+ Returns:
31
+ output: (Batch, Nsamples) or (Batch, Channel, Nsample)
32
+
33
+ Notice:
34
+ input is a complex tensor
35
+ """
36
+ assert input.dim() == 4 or input.dim() == 3, ErrorMessageUtil.only_support_batch_input
37
+ batch_size = input.size(0)
38
+ multi_channel = (input.dim() == 4)
39
+ if multi_channel:
40
+ input = rearrange(input, "b c f t -> (b c) f t")
41
+ window = self.window(
42
+ self.win_length,
43
+ dtype = input.real.dtype,
44
+ device = input.device
45
+ )
46
+ istft_kwargs = dict(
47
+ n_fft=self.n_fft,
48
+ win_length=self.win_length,
49
+ hop_length=self.hop_length,
50
+ center=self.center,
51
+ window=window,
52
+ length = audio_length,
53
+ return_complex = False
54
+ )
55
+
56
+ wave = torch.istft(input,**istft_kwargs)
57
+ if multi_channel:
58
+ wave = rearrange(wave,"(b c) l -> b c l", b = batch_size)
59
+ return wave
60
+
61
+ class ComplexTensorLayer(nn.Module):
62
+ def __init__(self):
63
+ super().__init__()
64
+ def forward(self,input):
65
+ assert input.shape[1] == 2, ErrorMessageUtil.complex_format_convert
66
+ real = input[:,0]
67
+ imag = input[:,1]
68
+
69
+ return torch.complex(real,imag)
network/layers/STFT_Layer.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional
4
+ from ..utils import ErrorMessageUtil
5
+ from einops import rearrange
6
+
7
+ class STFTLayer(nn.Module):
8
+ def __init__(
9
+ self,
10
+ n_fft:int = 128,
11
+ win_length: Optional[int] = None,
12
+ hop_length:int = 64,
13
+ window: str = "hann",
14
+ center: bool = True,
15
+ normalized: bool = False,
16
+ onesided: bool = True,
17
+ pad_mode:str ="reflect"
18
+ ):
19
+ super().__init__()
20
+ self.n_fft = n_fft
21
+ self.win_length = win_length if win_length else n_fft
22
+ self.hop_length = hop_length
23
+ self.center = center
24
+ self.normalized = normalized
25
+ self.onesided = onesided
26
+ self.pad_mode = pad_mode
27
+ self.window = getattr(torch,f"{window}_window")
28
+ def forward(self,input:torch.Tensor):
29
+ """STFT forward function.
30
+ Args:
31
+ input: (Batch, Nsamples) or (Batch, Channel, Nsample)
32
+ Returns:
33
+ output: (Batch, Freq, Frames) or (Batch, Channels, Freq, Frames)
34
+ Notice:
35
+ output is a complex tensor
36
+ """
37
+ assert input.dim() == 2 or input.dim() == 3, ErrorMessageUtil.only_support_batch_input
38
+ batch_size = input.size(0)
39
+ multi_channel = (input.dim() == 3)
40
+ if multi_channel:
41
+ input = rearrange(input, "b c l -> (b c) l")
42
+ window = self.window(
43
+ self.win_length,
44
+ dtype = input.dtype,
45
+ device = input.device
46
+ )
47
+
48
+ stft_kwargs = dict(
49
+ n_fft=self.n_fft,
50
+ win_length=self.n_fft,
51
+ hop_length=self.hop_length,
52
+ center=self.center,
53
+ window=window,
54
+ pad_mode=self.pad_mode,
55
+ return_complex=True
56
+ )
57
+
58
+ n_pad_left = (self.n_fft - window.shape[0]) // 2
59
+ n_pad_right = self.n_fft - window.shape[0] - n_pad_left
60
+ stft_kwargs["window"] = torch.cat(
61
+ [torch.zeros(n_pad_left,device=input.device), window, torch.zeros(n_pad_right,device=input.device)], 0
62
+ )
63
+
64
+ output = torch.stft(input,**stft_kwargs)
65
+ if multi_channel:
66
+ output = rearrange(output,"(b c) f t -> b c f t", b = batch_size)
67
+ return output
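As a quick sanity check of the two layers above, a round-trip sketch at the default n_fft=128 / hop_length=64 settings:

import torch
from network.layers import STFTLayer, InverseSTFTLayer

stft = STFTLayer(n_fft=128, hop_length=64)
istft = InverseSTFTLayer(n_fft=128, hop_length=64)

wave = torch.randn(2, 8000)                        # (batch, samples)
spec = stft(wave)                                  # complex tensor, (batch, 65 freq bins, frames)
recon = istft(spec, audio_length=wave.shape[-1])   # (batch, samples), matches the input length
print(spec.shape, recon.shape)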
network/layers/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .STFT_Layer import STFTLayer
2
+ from .ISTFT_layer import *
3
+ from .film_layer import FiLMLayer
network/layers/__pycache__/ISTFT_layer.cpython-311.pyc ADDED
Binary file (4.13 kB)
network/layers/__pycache__/STFTLayer.cpython-311.pyc ADDED
Binary file (3.66 kB)
network/layers/__pycache__/STFT_Layer.cpython-311.pyc ADDED
Binary file (3.63 kB)
network/layers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (358 Bytes)
network/layers/__pycache__/film_layer.cpython-311.pyc ADDED
Binary file (2.04 kB)
network/layers/film_layer.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch.nn as nn
2
+ from einops import rearrange
3
+
4
+ class FiLMLayer(nn.Module):
5
+ def __init__(self,channels,conditional_dim=256, apply_dim = 1):
6
+ super().__init__()
7
+ self.alpha = nn.Linear(conditional_dim,channels)
8
+ self.beta = nn.Linear(conditional_dim,channels)
9
+ self.apply_dim = apply_dim
10
+ def forward(self,x,condition):
11
+ alpha = self.alpha(condition)
12
+ beta = self.beta(condition)
13
+ input = x
14
+ if self.apply_dim != 1:
15
+ input = input.transpose(1,-1)
16
+ alpha = rearrange(alpha,"b d -> b d"+" 1"*(x.dim()-alpha.dim()))
17
+ beta = rearrange(beta,"b d -> b d"+" 1"*(x.dim()-beta.dim()))
18
+ out = alpha*input+beta
19
+ if self.apply_dim != 1:
20
+ out = out.transpose(1,-1)
21
+ return out
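The FiLM layer produces a per-channel scale (alpha) and shift (beta) from a conditioning vector and applies them to a feature map. A small usage sketch, assuming a (batch, channels, freq, frames) feature map and the default 256-dim condition:

import torch
from network.layers import FiLMLayer

film = FiLMLayer(channels=48, conditional_dim=256, apply_dim=1)
features = torch.randn(4, 48, 65, 100)    # (batch, channels, freq, frames)
condition = torch.randn(4, 256)           # e.g. a speaker embedding
out = film(features, condition)           # alpha * features + beta, broadcast per channel
print(out.shape)                          # torch.Size([4, 48, 65, 100])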
network/models/TF_gridnet_with_condition.py ADDED
@@ -0,0 +1,638 @@
1
+ import torch.nn as nn
2
+ from ..modules.tf_gridnet_modules import *
3
+ from ..modules.input_tranformation import STFTInput, RMSNormalizeInput
4
+ from ..modules.output_transformation import WaveGeneratorByISTFT, RMSDenormalizeOutput
5
+ from ..modules.convolution_module import SplitFeatureDeconv
6
+ from ..modules.attention import *
7
+ from .TF_gridnet import TF_Gridnet
8
+ from ..layers import FiLMLayer
9
+ from ..modules.gate_module import BandFilterGate
10
+ from einops import rearrange, repeat
11
+ import math
12
+ class TFGridFormer(nn.Module):
13
+ def __init__(
14
+ self,
15
+ n_srcs=2,
16
+ n_fft=128,
17
+ hop_length=64,
18
+ window="hann",
19
+ n_audio_channel=1,
20
+ n_layers=6,
21
+ input_kernel_size_T = 3,
22
+ input_kernel_size_F = 3,
23
+ output_kernel_size_T = 3,
24
+ output_kernel_size_F = 3,
25
+ lstm_hidden_units=192,
26
+ attn_n_head=4,
27
+ qk_output_channel=4,
28
+ emb_dim=48,
29
+ emb_ks=4,
30
+ emb_hs=1,
31
+ activation="PReLU",
32
+ eps=1.0e-5,):
33
+ super().__init__()
34
+ self.ref_input_normalize = RMSNormalizeInput((1,2),keepdim=True)
35
+ self.mix_input_normalize = RMSNormalizeInput((1,2),keepdim=True)
36
+ self.stft = STFTInput(
37
+ n_fft=n_fft,
38
+ win_length=n_fft,
39
+ hop_length=hop_length,
40
+ window=window,
41
+ )
42
+
43
+ self.mix_dimension_embedding = DimensionEmbedding(
44
+ audio_channel=n_audio_channel,
45
+ emb_dim=emb_dim,
46
+ kernel_size=(input_kernel_size_F,input_kernel_size_T),
47
+ eps=eps
48
+ )
49
+
50
+ self.ref_dimension_embedding = DimensionEmbedding(
51
+ audio_channel=n_audio_channel,
52
+ emb_dim=emb_dim,
53
+ kernel_size=(input_kernel_size_F,input_kernel_size_T),
54
+ eps=eps
55
+ )
56
+
57
+ self.istft = WaveGeneratorByISTFT(
58
+ n_fft=n_fft,
59
+ win_length=n_fft,
60
+ hop_length=hop_length,
61
+ window=window
62
+ )
63
+
64
+ self.output_denormalize = RMSDenormalizeOutput()
65
+
66
+ self.ref_encoder = TFGridnetBlock(
67
+ emb_dim=emb_dim,
68
+ kernel_size=emb_ks,
69
+ emb_hop_size=emb_hs,
70
+ hidden_channels=lstm_hidden_units,
71
+ n_head=attn_n_head,
72
+ qk_output_channel=qk_output_channel,
73
+ activation=activation,
74
+ eps=eps
75
+ )
76
+
77
+ mix_encode_layers = math.ceil(n_layers*2/3)
78
+ mix_decode_layers = n_layers - mix_encode_layers
79
+
80
+ self.mix_encoder = nn.Sequential(
81
+ *[
82
+ TFGridnetBlock(
83
+ emb_dim=emb_dim,
84
+ kernel_size=emb_ks,
85
+ emb_hop_size=emb_hs,
86
+ hidden_channels=lstm_hidden_units,
87
+ n_head=attn_n_head,
88
+ qk_output_channel=qk_output_channel,
89
+ activation=activation,
90
+ eps=eps
91
+ ) for _ in range(mix_encode_layers)
92
+ ]
93
+ )
94
+
95
+ self.split_layer = SplitFeatureDeconv(
96
+ emb_dim=emb_dim,
97
+ n_srcs=n_srcs,
98
+ kernel_size_T=output_kernel_size_T,
99
+ kernel_size_F=output_kernel_size_F,
100
+ padding_F=output_kernel_size_F//2,
101
+ padding_T=output_kernel_size_T//2
102
+ )
103
+
104
+ self.intra_frame_cross_att = IntraFrameCrossAttention(
105
+ emb_dim=emb_dim,
106
+ n_head= attn_n_head,
107
+ qk_output_channel=attn_n_head*3,
108
+ activation=activation,
109
+ eps=eps
110
+ )
111
+ self.cross_frame_cross_att = CrossFrameCrossAttention(
112
+ emb_dim=emb_dim,
113
+ n_head=attn_n_head,
114
+ qk_output_channel=qk_output_channel,
115
+ activation=activation,
116
+ eps=eps
117
+ )
118
+
119
+ self.mix_decoder = nn.Sequential(
120
+ *[
121
+ TFGridnetBlock(
122
+ emb_dim=emb_dim,
123
+ kernel_size=emb_ks,
124
+ emb_hop_size=emb_hs,
125
+ hidden_channels=lstm_hidden_units,
126
+ n_head=attn_n_head,
127
+ qk_output_channel=qk_output_channel,
128
+ activation=activation,
129
+ eps=eps
130
+ ) for _ in range(mix_decode_layers)
131
+ ]
132
+ )
133
+
134
+ self.deconv = TFGridnetDeconv(
135
+ emb_dim=emb_dim,
136
+ n_srcs=1,
137
+ kernel_size_T=output_kernel_size_T,
138
+ kernel_size_F=output_kernel_size_F,
139
+ padding_F=output_kernel_size_F//2,
140
+ padding_T=output_kernel_size_T//2
141
+ )
142
+
143
+ self.middle_deconv = TFGridnetDeconv(
144
+ emb_dim=emb_dim,
145
+ n_srcs=n_srcs,
146
+ kernel_size_T=output_kernel_size_T,
147
+ kernel_size_F=output_kernel_size_F,
148
+ padding_F=output_kernel_size_F//2,
149
+ padding_T=output_kernel_size_T//2
150
+ )
151
+
152
+ def forward(self,mix,ref,middle = False):
153
+ audio_length = mix.shape[-1]
154
+
155
+ x = mix
156
+ c = ref
157
+
158
+ if x.dim() == 2:
159
+ x = x.unsqueeze(1)
160
+ if c.dim() == 2:
161
+ c = c.unsqueeze(1)
162
+ x, std = self.mix_input_normalize(x)
163
+
164
+ x = self.stft(x)
165
+
166
+ c, _ = self.ref_input_normalize(c)
167
+
168
+ c = self.stft(c)
169
+
170
+ c = self.ref_dimension_embedding(c)
171
+
172
+ c = self.ref_encoder(c)
173
+
174
+ x = self.mix_dimension_embedding(x)
175
+
176
+ x = self.mix_encoder(x)
177
+
178
+ m = None
179
+ if middle:
180
+ m = self.middle_deconv(x)
181
+ m = rearrange(m,"B C N F T -> B N C F T")
182
+ m = self.istft(m,audio_length)
183
+ m = self.output_denormalize(m,std)
184
+
185
+
186
+ x = self.split_layer(x)
187
+
188
+ x = self.intra_frame_cross_att(c,x)
189
+
190
+ x = self.cross_frame_cross_att(c,x)
191
+
192
+ x = self.mix_decoder(x)
193
+
194
+ x = self.deconv(x)
195
+
196
+ x = rearrange(x,"B C N F T -> B N C F T") #becasue in istft, the 1 dim is for real and im part
197
+
198
+ x = self.istft(x,audio_length)
199
+
200
+ x = self.output_denormalize(x,std)
201
+ if middle: return x[:,0], m
202
+ return x[:,0]
203
+
204
+ class DoubleChannelTFGridNet(TF_Gridnet):
205
+ def __init__(self,
206
+ # n_srcs=2,
207
+ n_fft=128,
208
+ hop_length=64,
209
+ window="hann",
210
+ n_audio_channel=1,
211
+ n_layers=6,
212
+ input_kernel_size_T=3,
213
+ input_kernel_size_F=3,
214
+ output_kernel_size_T=3,
215
+ output_kernel_size_F=3,
216
+ lstm_hidden_units=192,
217
+ attn_n_head=4,
218
+ qk_output_channel=4,
219
+ emb_dim=48,
220
+ emb_ks=4,
221
+ emb_hs=1,
222
+ activation="PReLU",
223
+ eps=0.00001):
224
+ super().__init__(1, n_fft, hop_length, window, n_audio_channel*2, n_layers, input_kernel_size_T, input_kernel_size_F, output_kernel_size_T, output_kernel_size_F, lstm_hidden_units, attn_n_head, qk_output_channel, emb_dim, emb_ks, emb_hs, activation, eps)
225
+ def forward(self,input,condition):
226
+ x = input
227
+ c = condition
228
+
229
+ if x.dim() == 2:
230
+ x = x.unsqueeze(1)
231
+ if c.dim() == 2:
232
+ c = c.unsqueeze(1)
233
+ tc = c.shape[-1]
234
+ tx = x.shape[-1]
235
+ if tc >= tx:
236
+ c = c[:,:,-tx:]
237
+ else:
238
+ n = math.ceil(tx/tc)
239
+ c = repeat(c,"b c t -> b c (t n)",n=n)
240
+ c = c[:,:,-tx:]
241
+
242
+ mix_with_clue = torch.cat([x,c],dim=1)
243
+ o = super().forward(mix_with_clue)
244
+ return o[:,0]
245
+
246
+
247
+ class TargetSpeakerTF(nn.Module):
248
+ def __init__(
249
+ self,
250
+ n_srcs=2,
251
+ n_fft=128,
252
+ hop_length=64,
253
+ window="hann",
254
+ n_audio_channel=1,
255
+ n_layers=6,
256
+ input_kernel_size_T = 3,
257
+ input_kernel_size_F = 3,
258
+ output_kernel_size_T = 3,
259
+ output_kernel_size_F = 3,
260
+ lstm_hidden_units=192,
261
+ attn_n_head=4,
262
+ qk_output_channel=4,
263
+ emb_dim=48,
264
+ emb_ks=4,
265
+ emb_hs=1,
266
+ activation="PReLU",
267
+ eps=1.0e-5,
268
+ conditional_dim = 256
269
+ ):
270
+ super().__init__()
271
+
272
+ self.input_normalize = RMSNormalizeInput((1,2),keepdim=True)
273
+ self.stft = STFTInput(
274
+ n_fft=n_fft,
275
+ win_length=n_fft,
276
+ hop_length=hop_length,
277
+ window=window,
278
+ )
279
+
280
+ self.istft = WaveGeneratorByISTFT(
281
+ n_fft=n_fft,
282
+ win_length=n_fft,
283
+ hop_length=hop_length,
284
+ window=window
285
+ )
286
+
287
+ self.output_denormalize = RMSDenormalizeOutput()
288
+
289
+ self.dimension_embedding = DimensionEmbedding(
290
+ audio_channel=n_audio_channel,
291
+ emb_dim=emb_dim,
292
+ kernel_size=(input_kernel_size_F,input_kernel_size_T),
293
+ eps=eps
294
+ )
295
+
296
+ self.tf_gridnet_block = nn.ModuleList(
297
+ [
298
+ TFGridnetBlock(
299
+ emb_dim=emb_dim,
300
+ kernel_size=emb_ks,
301
+ emb_hop_size=emb_hs,
302
+ hidden_channels=lstm_hidden_units,
303
+ n_head=attn_n_head,
304
+ qk_output_channel=qk_output_channel,
305
+ activation=activation,
306
+ eps=eps
307
+ ) for _ in range(n_layers)
308
+ ]
309
+ )
310
+
311
+ self.film_layer = nn.ModuleList(
312
+ [
313
+ FiLMLayer(emb_dim,conditional_dim=conditional_dim,apply_dim=1)
314
+ for _ in range(n_layers)
315
+ ]
316
+ )
317
+
318
+ self.deconv = TFGridnetDeconv(
319
+ emb_dim=emb_dim,
320
+ n_srcs=n_srcs,
321
+ kernel_size_T=output_kernel_size_T,
322
+ kernel_size_F=output_kernel_size_F,
323
+ padding_F=output_kernel_size_F//2,
324
+ padding_T=output_kernel_size_T//2
325
+ )
326
+
327
+ self.n_layers = n_layers
328
+
329
+ def forward(self,input, clue):
330
+ audio_length = input.shape[-1]
331
+
332
+ x = input
333
+
334
+ if x.dim() == 2:
335
+ x = x.unsqueeze(1)
336
+
337
+ x, std = self.input_normalize(x)
338
+
339
+ x = self.stft(x)
340
+
341
+ x = self.dimension_embedding(x)
342
+ for i in range(self.n_layers):
343
+
344
+ x = self.tf_gridnet_block[i](x)
345
+ x = self.film_layer[i](x,clue)
346
+
347
+ x = self.deconv(x)
348
+
349
+ x = rearrange(x,"B C N F T -> B N C F T") #becasue in istft, the 1 dim is for real and im part
350
+
351
+ x = self.istft(x,audio_length)
352
+
353
+ x = self.output_denormalize(x,std)
354
+
355
+ return x
356
+
357
+ class DoubleChannelTargetSpeakerTF(TargetSpeakerTF):
358
+ def __init__(self,
359
+ # n_srcs=2,
360
+ n_fft=128,
361
+ hop_length=64,
362
+ window="hann",
363
+ n_audio_channel=1,
364
+ n_layers=6,
365
+ input_kernel_size_T=3,
366
+ input_kernel_size_F=3,
367
+ output_kernel_size_T=3,
368
+ output_kernel_size_F=3,
369
+ lstm_hidden_units=192,
370
+ attn_n_head=4,
371
+ qk_output_channel=4,
372
+ emb_dim=48,
373
+ emb_ks=4,
374
+ emb_hs=1,
375
+ activation="PReLU",
376
+ eps=0.00001,
377
+ conditional_dim = 256
378
+ ):
379
+ super().__init__(1, n_fft, hop_length, window, n_audio_channel*2, n_layers, input_kernel_size_T, input_kernel_size_F, output_kernel_size_T, output_kernel_size_F, lstm_hidden_units, attn_n_head, qk_output_channel, emb_dim, emb_ks, emb_hs, activation, eps, conditional_dim)
380
+ def forward(self, input, reference, embedding):
381
+ x = input
382
+ c = reference
383
+
384
+ if x.dim() == 2:
385
+ x = x.unsqueeze(1)
386
+ if c.dim() == 2:
387
+ c = c.unsqueeze(1)
388
+ tc = c.shape[-1]
389
+ tx = x.shape[-1]
390
+ if tc >= tx:
391
+ c = c[:,:,-tx:]
392
+ else:
393
+ n = math.ceil(tx/tc)
394
+ c = repeat(c,"b c t -> b c (t n)",n=n)
395
+ c = c[:,:,-tx:]
396
+
397
+ mix_with_clue = torch.cat([x,c],dim=1)
398
+ o = super().forward(mix_with_clue,embedding)
399
+ return o[:,0]
400
+
401
+ class FilterBandTFGridnet(nn.Module):
402
+ def __init__(
403
+ self,
404
+ # n_srcs=2,
405
+ n_fft=128,
406
+ hop_length=64,
407
+ window="hann",
408
+ n_audio_channel=1,
409
+ n_layers=6,
410
+ input_kernel_size_T = 3,
411
+ input_kernel_size_F = 3,
412
+ output_kernel_size_T = 3,
413
+ output_kernel_size_F = 3,
414
+ lstm_hidden_units=192,
415
+ attn_n_head=4,
416
+ qk_output_channel=4,
417
+ emb_dim=48,
418
+ emb_ks=4,
419
+ emb_hs=1,
420
+ activation="PReLU",
421
+ eps=1.0e-5,
422
+ conditional_dim = 256
423
+ ):
424
+ super().__init__()
425
+ n_freqs = n_fft//2 + 1
426
+ self.input_normalize = RMSNormalizeInput((1,2),keepdim=True)
427
+ self.stft = STFTInput(
428
+ n_fft=n_fft,
429
+ win_length=n_fft,
430
+ hop_length=hop_length,
431
+ window=window,
432
+ )
433
+
434
+ self.istft = WaveGeneratorByISTFT(
435
+ n_fft=n_fft,
436
+ win_length=n_fft,
437
+ hop_length=hop_length,
438
+ window=window
439
+ )
440
+
441
+ self.output_denormalize = RMSDenormalizeOutput()
442
+
443
+ self.dimension_embedding = DimensionEmbedding(
444
+ audio_channel=n_audio_channel,
445
+ emb_dim=emb_dim,
446
+ kernel_size=(input_kernel_size_F,input_kernel_size_T),
447
+ eps=eps
448
+ )
449
+
450
+ self.tf_gridnet_block = nn.ModuleList(
451
+ [
452
+ TFGridnetBlock(
453
+ emb_dim=emb_dim,
454
+ kernel_size=emb_ks,
455
+ emb_hop_size=emb_hs,
456
+ hidden_channels=lstm_hidden_units,
457
+ n_head=attn_n_head,
458
+ qk_output_channel=qk_output_channel,
459
+ activation=activation,
460
+ eps=eps
461
+ ) for _ in range(n_layers)
462
+ ]
463
+ )
464
+
465
+ self.filter_gen = nn.Linear(conditional_dim,emb_dim*n_freqs)
466
+ self.bias_gen = nn.Linear(conditional_dim,emb_dim*n_freqs)
467
+
468
+
469
+ self.gates = nn.ModuleList(
470
+ [
471
+ BandFilterGate(emb_dim,n_freqs)
472
+ for _ in range(n_layers)
473
+ ]
474
+ )
475
+
476
+ self.deconv = TFGridnetDeconv(
477
+ emb_dim=emb_dim,
478
+ n_srcs=1,
479
+ kernel_size_T=output_kernel_size_T,
480
+ kernel_size_F=output_kernel_size_F,
481
+ padding_F=output_kernel_size_F//2,
482
+ padding_T=output_kernel_size_T//2
483
+ )
484
+
485
+ self.n_layers = n_layers
486
+ def forward(self,input, clue):
487
+ audio_length = input.shape[-1]
488
+
489
+ x = input
490
+
491
+ if x.dim() == 2:
492
+ x = x.unsqueeze(1)
493
+
494
+ x, std = self.input_normalize(x)
495
+
496
+ x = self.stft(x)
497
+
498
+ x = self.dimension_embedding(x)
499
+
500
+ n_freqs = x.shape[-2]
501
+ f = self.filter_gen(clue)
502
+ b = self.bias_gen(clue)
503
+ f = rearrange(f,"b (d q) -> b d q 1", q = n_freqs)
504
+ b = rearrange(b,"b (d q) -> b d q 1", q = n_freqs)
505
+
506
+ for i in range(self.n_layers):
507
+
508
+ x = self.tf_gridnet_block[i](x)
509
+ x = self.gates[i](x,f,b)
510
+
511
+ x = self.deconv(x)
512
+
513
+ x = rearrange(x,"B C N F T -> B N C F T") #becasue in istft, the 1 dim is for real and im part
514
+
515
+ x = self.istft(x,audio_length)
516
+
517
+ x = self.output_denormalize(x,std)
518
+
519
+ return x[:,0]
520
+
521
+ class FilterBandTFGridnetWithAttentionGate(nn.Module):
522
+ def __init__(
523
+ self,
524
+ # n_srcs=2,
525
+ n_fft=128,
526
+ hop_length=64,
527
+ window="hann",
528
+ n_audio_channel=1,
529
+ n_layers=6,
530
+ input_kernel_size_T = 3,
531
+ input_kernel_size_F = 3,
532
+ output_kernel_size_T = 3,
533
+ output_kernel_size_F = 3,
534
+ lstm_hidden_units=192,
535
+ attn_n_head=4,
536
+ qk_output_channel=4,
537
+ emb_dim=48,
538
+ emb_ks=4,
539
+ emb_hs=1,
540
+ activation="PReLU",
541
+ eps=1.0e-5,
542
+ conditional_dim = 256
543
+ ):
544
+ super().__init__()
545
+ n_freqs = n_fft//2 + 1
546
+ self.input_normalize = RMSNormalizeInput((1,2),keepdim=True)
547
+ self.stft = STFTInput(
548
+ n_fft=n_fft,
549
+ win_length=n_fft,
550
+ hop_length=hop_length,
551
+ window=window,
552
+ )
553
+
554
+ self.istft = WaveGeneratorByISTFT(
555
+ n_fft=n_fft,
556
+ win_length=n_fft,
557
+ hop_length=hop_length,
558
+ window=window
559
+ )
560
+
561
+ self.output_denormalize = RMSDenormalizeOutput()
562
+
563
+ self.dimension_embedding = DimensionEmbedding(
564
+ audio_channel=n_audio_channel,
565
+ emb_dim=emb_dim,
566
+ kernel_size=(input_kernel_size_F,input_kernel_size_T),
567
+ eps=eps
568
+ )
569
+
570
+ self.tf_gridnet_block = nn.ModuleList(
571
+ [
572
+ TFGridnetBlock(
573
+ emb_dim=emb_dim,
574
+ kernel_size=emb_ks,
575
+ emb_hop_size=emb_hs,
576
+ hidden_channels=lstm_hidden_units,
577
+ n_head=attn_n_head,
578
+ qk_output_channel=qk_output_channel,
579
+ activation=activation,
580
+ eps=eps
581
+ ) for _ in range(n_layers)
582
+ ]
583
+ )
584
+
585
+ self.query_gen = nn.Linear(conditional_dim,emb_dim*n_freqs)
586
+
587
+
588
+ self.attentions = nn.ModuleList(
589
+ [
590
+ CrossAttentionFilterV2(emb_dim)
591
+ for _ in range(n_layers)
592
+ ]
593
+ )
594
+
595
+ self.deconv = TFGridnetDeconv(
596
+ emb_dim=emb_dim,
597
+ n_srcs=1,
598
+ kernel_size_T=output_kernel_size_T,
599
+ kernel_size_F=output_kernel_size_F,
600
+ padding_F=output_kernel_size_F//2,
601
+ padding_T=output_kernel_size_T//2
602
+ )
603
+
604
+ self.n_layers = n_layers
605
+
606
+ def forward(self,input, clue):
607
+ audio_length = input.shape[-1]
608
+
609
+ x = input
610
+
611
+ if x.dim() == 2:
612
+ x = x.unsqueeze(1)
613
+
614
+ x, std = self.input_normalize(x)
615
+
616
+ x = self.stft(x)
617
+
618
+ x = self.dimension_embedding(x)
619
+
620
+ n_freqs = x.shape[-2]
621
+
622
+ q = self.query_gen(clue)
623
+ q = rearrange(q,"b (d f) -> b f d", f=n_freqs)
624
+
625
+ for i in range(self.n_layers):
626
+
627
+ x = self.tf_gridnet_block[i](x)
628
+ x = self.attentions[i](q,x)
629
+
630
+ x = self.deconv(x)
631
+
632
+ x = rearrange(x,"B C N F T -> B N C F T") #becasue in istft, the 1 dim is for real and im part
633
+
634
+ x = self.istft(x,audio_length)
635
+
636
+ x = self.output_denormalize(x,std)
637
+
638
+ return x[:,0]
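For orientation, a shape-level sketch of one FilterBandTFGridnet forward pass at the configuration app.py uses (five blocks and a 512-dim condition built from two concatenated 256-dim speaker embeddings); treating the model's working rate as 8 kHz is an assumption carried over from the app:

import torch
from network.models import FilterBandTFGridnet

model = FilterBandTFGridnet(n_layers=5, conditional_dim=256 * 2)
mixture = torch.randn(1, 8000 * 4)         # 4 s of audio at the assumed 8 kHz model rate
condition = torch.randn(1, 512)            # concat of reference and mixture speaker embeddings
with torch.no_grad():
    estimate = model(mixture, condition)   # (1, 32000): estimated target-speaker waveform
print(estimate.shape)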
network/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .TF_gridnet_with_condition import *
2
+ from .embedding_model import ResemblyzerVoiceEncoder
network/models/__pycache__/TF_gridnet.cpython-311.pyc ADDED
Binary file (4 kB)
network/models/__pycache__/TF_gridnet_with_condition.cpython-311.pyc ADDED
Binary file (22.4 kB)
network/models/__pycache__/TF_gridnet_with_condition_new.cpython-311.pyc ADDED
Binary file (5.3 kB)
network/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (592 Bytes)
network/models/__pycache__/embedding_model.cpython-311.pyc ADDED
Binary file (2.14 kB)
network/models/__pycache__/new_style.cpython-311.pyc ADDED
Binary file (8.66 kB)
network/models/__pycache__/tfGrideNetLOTH.cpython-311.pyc ADDED
Binary file (5.41 kB)
network/models/embedding_model.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch
2
+ from resemblyzer import VoiceEncoder
3
+
4
+ class ResemblyzerVoiceEncoder:
5
+ def __init__(self, device) -> None:
6
+ self.model = VoiceEncoder(device)
7
+
8
+ def __call__(self, audio: torch.Tensor):
9
+ if audio.ndimension() == 1:
10
+ return torch.tensor(self.model.embed_utterance(audio.numpy())).float().cpu()
11
+ else:
12
+ e = torch.stack([torch.tensor(self.model.embed_utterance(audio[i,:].numpy())).float().cpu()
13
+ for i in range(audio.shape[0])])
14
+ return e
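A short sketch of how this wrapper builds the conditioning vector for the separator; concatenating the reference and mixture embeddings matches app.py, and the file paths are placeholders:

import torch
import torchaudio
from network.models import ResemblyzerVoiceEncoder

encoder = ResemblyzerVoiceEncoder(device='cpu')
reference, _ = torchaudio.load('reference.wav')    # placeholder path, (channels, samples)
mixture, _ = torchaudio.load('mixture.wav')        # placeholder path
e_ref = encoder(reference)                         # (channels, 256)
e_mix = encoder(mixture)                           # (channels, 256)
condition = torch.cat([e_ref, e_mix], dim=1)       # (channels, 512), fed to the separator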
network/modules/__init__.py ADDED
File without changes
network/modules/attention.py ADDED
@@ -0,0 +1,197 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .tf_gridnet_modules import AllHeadPReLULayerNormalization4DC, LayerNormalization
5
+ from einops import rearrange, repeat
6
+ import math
7
+
8
+
9
+ class IntraFrameCrossAttention(nn.Module):
10
+ def __init__(
11
+ self,
12
+ emb_dim = 48,
13
+ n_head = 4,
14
+ qk_output_channel=12,
15
+ activation="PReLU",
16
+ eps = 1e-5
17
+ ):
18
+ super().__init__()
19
+ assert emb_dim % n_head == 0
20
+ E = qk_output_channel
21
+ self.conv_Q = nn.Conv2d(emb_dim,n_head*E,1)
22
+ self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
23
+
24
+ self.conv_K = nn.Conv2d(emb_dim,n_head*E,1)
25
+ self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
26
+
27
+ self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
28
+ self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)
29
+
30
+ self.concat_proj = nn.Sequential(
31
+ nn.Conv2d(emb_dim,emb_dim,1),
32
+ getattr(nn,activation)(),
33
+ LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
34
+ )
35
+ self.emb_dim = emb_dim
36
+ self.n_head = n_head
37
+ def forward(self,q,kv):
38
+ """
39
+ args:
40
+ query (torch.Tensor): a query for cross attention, coming from the reference encoder
41
+ [B D Q Tq]
42
+ kv (torch.Tensor): a key and value for cross attention, coming from the output of the feature split
43
+ [B nSrc D Q Tkv]
44
+ output:
45
+ output: (torch.Tensor):[B D Q Tkv]
46
+ """
47
+
48
+ B, D, freq, Tq = q.shape
49
+
50
+ _, nSrc, _, _, Tkv = kv.shape
51
+ if Tq >= Tkv:
52
+ q = q[:,:,:,-Tkv:]
53
+ else:
54
+ r = math.ceil(Tkv/Tq)
55
+ q = repeat(q,"B D Q T -> B D Q (T r)", r = r)
56
+ q = q[:,:,:,-Tkv:]
57
+ query = rearrange(q,"B D Q T -> B D T Q")
58
+ kvInput = rearrange(kv,"B n D Q T -> B D T (n Q)")
59
+
60
+ Q = self.norm_Q(self.conv_Q(query)) # [B, n_head, C, T, Q]
61
+ K = self.norm_K(self.conv_K(kvInput)) # [B, n_head, C, T, Q*nSrc]
62
+ V = self.norm_V(self.conv_V(kvInput))
63
+
64
+ Q = rearrange(Q, "B H C T Q -> (B H T) Q C")
65
+ K = rearrange(K, "B H C T Q -> (B H T) C Q").contiguous()
66
+ _, n_head, channel, _, _ = V.shape
67
+ V = rearrange(V, "B H C T Q -> (B H T) Q C")
68
+
69
+ emb_dim = Q.shape[-1]
70
+ qkT = torch.matmul(Q, K) / (emb_dim**0.5)
71
+ qkT = F.softmax(qkT,dim=2)
72
+
73
+ att = torch.matmul(qkT,V)
74
+ att = rearrange(att, "(B H T) Q C -> B (H C) T Q", C=channel, Q=freq, H = n_head, B = B, T=Tkv)
75
+ att = self.concat_proj(att)
76
+
77
+ out = att + query
78
+ out = rearrange(out, "B C T Q -> B C Q T")
79
+ return out
80
+
81
+
82
+ class CrossFrameCrossAttention(nn.Module):
83
+ def __init__(
84
+ self,
85
+ emb_dim = 48,
86
+ n_head=4,
87
+ qk_output_channel=4,
88
+ activation="PReLU",
89
+ eps = 1e-5
90
+
91
+ ):
92
+ super().__init__()
93
+ assert emb_dim % n_head == 0
94
+ E = qk_output_channel
95
+ self.conv_Q = nn.Conv2d(emb_dim,n_head*E,1)
96
+ self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
97
+
98
+ self.conv_K = nn.Conv2d(emb_dim,n_head*E,1)
99
+ self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
100
+
101
+ self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
102
+ self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)
103
+
104
+ self.concat_proj = nn.Sequential(
105
+ nn.Conv2d(emb_dim,emb_dim,1),
106
+ getattr(nn,activation)(),
107
+ LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
108
+ )
109
+ self.emb_dim = emb_dim
110
+ self.n_head = n_head
111
+ def forward(self,q,kv):
112
+ """
113
+ args:
114
+ query (torch.Tensor): a query for cross attention, coming from the reference encoder
115
+ [B D Q Tq]
116
+ kv (torch.Tensor): a key and value for cross attention, coming from the output of the feature split
117
+ [B D Q Tkv]
118
+ output:
119
+ output: (torch.Tensor):[B D Q Tkv]
120
+ """
121
+ Tq = q.shape[-1]
122
+ Tkv = kv.shape[-1]
123
+ if Tq >= Tkv:
124
+ q = q[:,:,:,-Tkv:]
125
+ else:
126
+ r = math.ceil(Tkv/Tq)
127
+ q = repeat(q,"B D Q T -> B D Q (T r)", r = r)
128
+ q = q[:,:,:,-Tkv:]
129
+
130
+ input = rearrange(q,"B C Q T -> B C T Q")
131
+ kvInput = rearrange(kv,"B C Q T -> B C T Q")
132
+
133
+ Q = self.norm_Q(self.conv_Q(input)) # [B, n_head, C, T, Q]
134
+ K = self.norm_K(self.conv_K(kvInput))
135
+ V = self.norm_V(self.conv_V(kvInput))
136
+
137
+ Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
138
+ K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
139
+ batch, n_head, channel, frame, freq = V.shape
140
+ V = rearrange(V, "B H C T Q -> (B H) T (C Q)")
141
+ emb_dim = Q.shape[-1]
142
+ qkT = torch.matmul(Q, K) / (emb_dim**0.5)
143
+ qkT = F.softmax(qkT,dim=2)
144
+ att = torch.matmul(qkT,V)
145
+ att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q", C=channel, Q=freq, H = n_head, B = batch, T=frame)
146
+ att = self.concat_proj(att)
147
+ out = att + input
148
+ out = rearrange(out, "B C T Q -> B C Q T")
149
+ return out
150
+
151
+ class CrossAttentionFilter(nn.Module):
152
+ def __init__(self, emb_dim = 48) -> None:
153
+ super().__init__()
154
+ self.emb_dim = emb_dim
155
+
156
+ def forward(self, q, k, v):
157
+ """
158
+ Args:
159
+ q (torch.Tensor): from the previous layer, [B D F T]
160
+ k (torch.Tensor): from the speaker embedding encoder, [B D]
161
+ v (torch.Tensor): from the speaker embedding encoder, [B D]
162
+ """
163
+
164
+ B, D, _, T = q.shape
165
+
166
+ q = rearrange(q, "B D F T -> (B T) F D")
167
+ k = repeat(k, "B D -> (B T) D 1", T = T)
168
+ v = repeat(v, "B D -> (B T) 1 D", T = T)
169
+
170
+ qkT = torch.matmul(q, k)/(D**0.5) # [(B T) F 1]
171
+ qkT = F.softmax(qkT, dim=-1)
172
+ att = torch.matmul(qkT, v) # [(B T) F D]
173
+ att = rearrange(att, "(B T) F D -> B D F T", B = B, T = T)
174
+ return att
175
+
176
+ class CrossAttentionFilterV2(nn.Module):
177
+ def __init__(self, emb_dim = 48) -> None:
178
+ super().__init__()
179
+ self.emb_dim = emb_dim
180
+ def forward(self,q, kv):
181
+ """
182
+ Args:
183
+ q: torch.Tensor, [B F D] a query for cross attention, coming from the reference encoder (speaker embedding)
184
+ kv: torch.Tensor, [B D F T] a key and value for cross attention, coming from the output of the previous layer (TF-GridNet)
185
+ """
186
+
187
+ B, D, _, T = kv.shape
188
+
189
+ Q = repeat(q, "B F D -> (B T) F D", T = T)
190
+ K = rearrange(kv, "B D F T -> (B T) D F")
191
+ V = rearrange(kv, "B D F T -> (B T) F D")
192
+
193
+ qkT = torch.matmul(Q,K)/(D**0.5) #[(B T) F F]
194
+ qkT = F.softmax(qkT, dim=-1)
195
+ att = torch.matmul(qkT, V) # [(B T) F D]
196
+ att = rearrange(att, "(B T) F D -> B D F T", B = B, T = T)
197
+ return att
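A shape sketch for CrossAttentionFilterV2, the variant used by FilterBandTFGridnetWithAttentionGate: the query is derived from the speaker embedding, while keys and values come from the previous TF-GridNet block. (Importing network.modules.attention assumes AllHeadPReLULayerNormalization4DC is exported by the tf_gridnet_modules package, as the import at the top of this file requires.)

import torch
from network.modules.attention import CrossAttentionFilterV2

att = CrossAttentionFilterV2(emb_dim=48)
q = torch.randn(2, 65, 48)           # (B, F, D) query built from the speaker embedding
kv = torch.randn(2, 48, 65, 126)     # (B, D, F, T) features from the previous block
out = att(q, kv)                     # (B, D, F, T), same shape as kv
print(out.shape)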
network/modules/gate_module.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ class BandFilterGate(nn.Module):
6
+ def __init__(self,emb_dim=48, n_freqs = 65):
7
+ super().__init__()
8
+ self.alpha = nn.Parameter(torch.empty(1,emb_dim,n_freqs,1).to(torch.float32))
9
+ self.beta = nn.Parameter(torch.empty(1,emb_dim,n_freqs,1).to(torch.float32))
10
+ nn.init.xavier_normal_(self.alpha)
11
+ nn.init.xavier_normal_(self.beta)
12
+ def forward(self,input,filters,bias):
13
+ f = F.sigmoid(self.alpha*filters)
14
+ b = F.tanh(self.beta*bias)
15
+ return f*input + b
16
+
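BandFilterGate expects per-frequency filter and bias maps of shape (B, D, F, 1); a usage sketch, assuming they are generated from a conditioning vector with linear layers the way FilterBandTFGridnet does:

import torch
import torch.nn as nn
from einops import rearrange
from network.modules.gate_module import BandFilterGate

emb_dim, n_freqs, cond_dim = 48, 65, 512
gate = BandFilterGate(emb_dim=emb_dim, n_freqs=n_freqs)
filter_gen = nn.Linear(cond_dim, emb_dim * n_freqs)    # mirrors FilterBandTFGridnet
bias_gen = nn.Linear(cond_dim, emb_dim * n_freqs)

x = torch.randn(2, emb_dim, n_freqs, 100)              # (B, D, F, T) features
clue = torch.randn(2, cond_dim)                        # conditioning vector
f = rearrange(filter_gen(clue), "b (d q) -> b d q 1", q=n_freqs)
b = rearrange(bias_gen(clue), "b (d q) -> b d q 1", q=n_freqs)
out = gate(x, f, b)                                    # sigmoid(alpha*f) * x + tanh(beta*b)
print(out.shape)                                       # torch.Size([2, 48, 65, 100])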
network/modules/input_tranformation.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..layers import STFTLayer
4
+ from ..utils import STFT_transform_type_enum
5
+ from typing import Iterable
6
+
7
+ class SimpleConv1DInput(nn.Module):
8
+ def __init__(self, in_channels, out_channels, kernel, stride = 1,
9
+ padding = 0, dilation=1, groups=1, bias=True,
10
+ padding_mode='zeros', activation:str = 'ReLU'):
11
+ super().__init__()
12
+ activation = getattr(nn,activation)
13
+ self.model = nn.Sequential(
14
+ nn.Conv1d(in_channels,out_channels,kernel,stride,padding,dilation,groups,bias,padding_mode),
15
+ activation()
16
+ )
17
+ def forward(self,input):
18
+ return self.model(input)
19
+
20
+ class STFTInput(nn.Module):
21
+ def __init__(
22
+ self,
23
+ n_fft: int = 128,
24
+ win_length: int = None,
25
+ hop_length: int = 64,
26
+ window="hann",
27
+ center: bool = True,
28
+ normalized: bool = False,
29
+ onesided: bool = True,
30
+ spec_transform_type: str = None,
31
+ spec_factor: float = 0.15,
32
+ spec_abs_exponent: float = 0.5,
33
+ ):
34
+ super().__init__()
35
+ self.stft = STFTLayer(
36
+ n_fft,
37
+ win_length,
38
+ hop_length,
39
+ window,
40
+ center,
41
+ normalized,
42
+ onesided
43
+ )
44
+
45
+ self.spec_transform_type = spec_transform_type
46
+ self.spec_factor = spec_factor
47
+ self.spec_abs_exponent = spec_abs_exponent
48
+
49
+ self.spec_transform = lambda spec: spec
50
+ if self.spec_transform_type == STFT_transform_type_enum.exponent:
51
+ self.spec_transform = lambda spec: spec.abs() ** self.spec_abs_exponent * torch.exp(1j * spec.angle())
52
+ elif self.spec_transform_type == STFT_transform_type_enum.log:
53
+ self.spec_transform = lambda spec: torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle()) * self.spec_factor
54
+
55
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
56
+ def forward(self,input):
57
+ """
58
+ Note that PyTorch's STFT does not support half-precision (16-bit) float input, so this function
59
+ is decorated with @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
60
+ Args:
61
+ input (torch.Tensor): signal [Batch, Nsamples] or [Batch,channel,Nsamples]
62
+ outputs:
63
+ spectrum (torch.Tensor): float tensor representing the spectrum with 2 channels; the first channel
64
+ is the real part of the spectrum, the second channel is the imaginary part
65
+ [Batch, 2, F, T] or [Batch, 2 * channel, F, T]
66
+ """
67
+
68
+ spectrum = self.stft(input.float())
69
+ spectrum = self.spec_transform(spectrum)
70
+
71
+ re = spectrum.real
72
+ im = spectrum.imag
73
+
74
+ if input.dim() == 2:
75
+ re = re.unsqueeze(1)
76
+ im = im.unsqueeze(1)
77
+
78
+ if input.dtype in (torch.float16, torch.bfloat16):
79
+ re = re.to(dtype=input.dtype)
80
+ im = im.to(dtype=input.dtype)
81
+
82
+ return torch.cat([re,im],dim=1)
83
+
84
+ class RMSNormalizeInput(nn.Module):
85
+ def __init__(self, dim: Iterable[int], keepdim:bool = True) -> None:
86
+ super().__init__()
87
+ self.dim = dim
88
+ self.keepdim = keepdim
89
+ def forward(self,input):
90
+ std = torch.std(input,dim=self.dim,keepdim=self.keepdim)
91
+ output = input/std
92
+ return output, std
network/modules/output_transformation.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..layers import InverseSTFTLayer, ComplexTensorLayer
4
+ from typing import Iterable, Optional
5
+
6
+ class WaveGeneratorByISTFT(nn.Module):
7
+ def __init__(
8
+ self,
9
+ n_fft:int = 128,
10
+ win_length: Optional[int] = None,
11
+ hop_length:int = 64,
12
+ window: str = "hann",
13
+ center: bool = True,
14
+ normalized: bool = False,
15
+ onesided: bool = True
16
+ ) -> None:
17
+ super().__init__()
18
+ self.istft = InverseSTFTLayer(
19
+ n_fft,
20
+ win_length,
21
+ hop_length,
22
+ window,
23
+ center,
24
+ normalized,
25
+ onesided
26
+ )
27
+
28
+ self.float_to_complex = ComplexTensorLayer()
29
+
30
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
31
+ def forward(self,input,length:int=None):
32
+ x = input
33
+ if input.dtype in (torch.float16, torch.bfloat16):
34
+ x = input.float()
35
+ if x.dtype in (torch.float32,):
36
+ x = self.float_to_complex(x)
37
+
38
+ wav = self.istft(x,length)
39
+ wav = wav.to(dtype=input.dtype)
40
+
41
+ return wav
42
+
43
+ class RMSDenormalizeOutput(nn.Module):
44
+ def __init__(self) -> None:
45
+ super().__init__()
46
+ def forward(self,input,std):
47
+ return input*std
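RMSNormalizeInput and RMSDenormalizeOutput form a normalize / process / denormalize pair around the network; a round-trip sketch (the input module lives in input_tranformation.py, spelled as committed):

import torch
from network.modules.input_tranformation import RMSNormalizeInput
from network.modules.output_transformation import RMSDenormalizeOutput

normalize = RMSNormalizeInput(dim=(1, 2), keepdim=True)
denormalize = RMSDenormalizeOutput()

wave = torch.randn(3, 1, 16000) * 5.0     # (batch, channel, samples), arbitrary scale
normed, std = normalize(wave)             # scaled input plus the per-item std
restored = denormalize(normed, std)       # equals wave up to floating-point error
print(torch.allclose(restored, wave, atol=1e-6))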
network/modules/sequence_embed.py ADDED
@@ -0,0 +1,138 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from .tf_gridnet_modules import CrossFrameSelfAttention
6
+
7
+ class SequenceEmbed(nn.Module):
8
+ def __init__(
9
+ self,
10
+ emb_dim: int = 48,
11
+ n_fft: int = 128,
12
+ hidden_size: int = 192,
13
+ kernel_T: int = 5,
14
+ kernel_F: int = 5,
15
+ ):
16
+ super().__init__()
17
+
18
+ self.n_freqs = n_fft // 2 + 1
19
+ self.emb_dim = emb_dim
20
+
21
+ self.conv = nn.Sequential(
22
+ nn.Conv2d(emb_dim*2,emb_dim*2,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2),groups=emb_dim*2),
23
+ nn.PReLU(),
24
+ nn.Conv2d(emb_dim*2,emb_dim*2,1),
25
+ nn.PReLU(),
26
+ nn.Conv2d(emb_dim*2,emb_dim*2,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2),groups=emb_dim*2),
27
+ nn.PReLU(),
28
+ nn.Conv2d(emb_dim*2,emb_dim,1),
29
+ nn.PReLU(),
30
+ )
31
+
32
+ self.linear_pre = nn.Conv1d(emb_dim*self.n_freqs,hidden_size,1)
33
+
34
+ self.lstm = nn.LSTM(
35
+ hidden_size,hidden_size,1,batch_first=True,bidirectional=True
36
+ )
37
+
38
+ self.linear = nn.Linear(hidden_size*2,emb_dim*self.n_freqs)
39
+
40
+ self.filter_gen = nn.Conv1d(emb_dim,emb_dim,1)
41
+ self.bias_gen = nn.Conv1d(emb_dim,emb_dim,1)
42
+ def forward(self,x,ref):
43
+ """
44
+ Args:
45
+ x: (B, D, F, T) input tensor from the previous layer
46
+ ref: (B, D, F, T) embedding tensor from the previous layer
47
+ """
48
+ B, D, n_freq, T = x.shape
49
+ input = torch.cat([x,ref],dim=1)
50
+ input = self.conv(input)
51
+ input = rearrange(input,'B D F T -> B (D F) T')
52
+ input = self.linear_pre(input)
53
+ input = rearrange(input,'B C T -> B T C')
54
+ rnn , _ = self.lstm(input)
55
+ feature = rnn[:,0]+rnn[:,-1] # (B, 2*Hidden)
56
+ feature = self.linear(feature) # (B, D*F)
57
+ feature = rearrange(feature,'B (D F) -> B D F',D=D,F=n_freq)
58
+ f = self.filter_gen(feature)
59
+ b = self.bias_gen(feature)
60
+
61
+ return f.unsqueeze(-1), b.unsqueeze(-1)
62
+
63
+ class CrossFrameCrossAttention(CrossFrameSelfAttention):
64
+ def __init__(self, emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4, activation="PReLU", eps=0.00001):
65
+ super().__init__(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
66
+
67
+ def forward(self, q, kv):
68
+ """
69
+ Args:
70
+ q: (B, D, F, T) query tensor
71
+ kv: (B, D, F, T) key and value tensor
72
+ """
73
+
74
+ input_q = rearrange(q,"B C Q T -> B C T Q")
75
+ input_kv = rearrange(kv,"B C Q T -> B C T Q")
76
+
77
+ Q = self.norm_Q(self.conv_Q(input_q))
78
+ K = self.norm_K(self.conv_K(input_kv))
79
+ V = self.norm_V(self.conv_V(input_kv))
80
+ Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
81
+ K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
82
+ batch, n_head, channel, frame, freq = V.shape
83
+ V = rearrange(V, "B H C T Q -> (B H) T (C Q)")
84
+ emb_dim = Q.shape[-1]
85
+ qkT = torch.matmul(Q, K) / (emb_dim**0.5)
86
+ qkT = F.softmax(qkT,dim=2)
87
+ att = torch.matmul(qkT,V)
88
+ att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q", C=channel, Q=freq, H = n_head, B = batch, T=frame)
89
+ att = self.concat_proj(att)
90
+ out = att + input_q
91
+ out = rearrange(out, "B C T Q -> B C Q T")
92
+ return out
93
+
94
+ class MutualAttention(nn.Module):
95
+ def __init__(self,kernel_T=5, kernel_F=5 ,emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4, activation="PReLU", eps=0.00001):
96
+ super().__init__()
97
+
98
+ self.ref_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
99
+ self.tar_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
100
+
101
+ self.mt_conv = nn.Sequential(
102
+ nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
103
+ nn.PReLU(),
104
+ nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
105
+ nn.Sigmoid()
106
+ )
107
+
108
+ self.mr_conv = nn.Sequential(
109
+ nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
110
+ nn.PReLU(),
111
+ nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
112
+ nn.Sigmoid()
113
+ )
114
+
115
+ self.mtr_conv = nn.Sequential(
116
+ nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
117
+ nn.PReLU(),
118
+ nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
119
+ nn.PReLU()
120
+ )
121
+ def forward(self,tar,ref):
122
+ """
123
+ Args:
124
+ ref: (B, D, F, T) reference tensor
125
+ tar: (B, D, F, T) target tensor
126
+ """
127
+
128
+ mr = self.ref_att(ref,tar)
129
+ mt = self.tar_att(tar,ref)
130
+
131
+ mrt = mr + mt
132
+
133
+ mr = self.mr_conv(mr)
134
+ mt = self.mt_conv(mt)
135
+ mrt_o = self.mtr_conv(mrt)
136
+
137
+ o = mr*mt*mrt_o + mrt
138
+ return o
network/modules/split_modules.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+
6
+ class DimensionDimAttention(nn.Module):
7
+ def __init__(
8
+ self,
9
+ emb_dim: int = 48,
10
+ kernel_size: int = (7,7),
11
+ dilation: int = 2,
12
+ ) -> None:
13
+ super().__init__()
14
+ self.emb_dim = emb_dim
15
+
16
+ self.attn = nn.Sequential(
17
+ nn.Conv2d(2*emb_dim,2*emb_dim,1),
18
+ nn.GELU(),
19
+ nn.Conv2d(2*emb_dim,2*emb_dim,kernel_size=kernel_size,groups=2*emb_dim,padding="same"),
20
+ nn.PReLU(),
21
+ nn.Conv2d(2*emb_dim,2*emb_dim,kernel_size=kernel_size,dilation=dilation,groups=2*emb_dim,padding="same"),
22
+ nn.PReLU(),
23
+ nn.Conv2d(2*emb_dim,emb_dim,1),
24
+ nn.Sigmoid()
25
+ )
26
+
27
+ self.transform = nn.Sequential(
28
+ nn.Conv2d(emb_dim,emb_dim,1),
29
+ nn.GELU(),
30
+ nn.Conv2d(emb_dim,emb_dim,kernel_size=kernel_size,groups=emb_dim,padding="same"),
31
+ nn.PReLU(),
32
+ nn.Conv2d(emb_dim,emb_dim,kernel_size=kernel_size,dilation=dilation,groups=emb_dim,padding="same"),
33
+ nn.PReLU(),
34
+ nn.Conv2d(emb_dim,emb_dim,1)
35
+ )
36
+
37
+ def forward(self,x,e):
38
+ """
39
+ Args:
40
+ x: (B, D, F, T) input tensor from the previous layer
41
+ e: (B, D, F) embedding after reshape
42
+ """
43
+
44
+ T = x.shape[-1]
45
+ emb = repeat(e, 'B D F -> B D F T', T=T)
46
+ att = torch.cat([x,emb],dim=1)
47
+
48
+ i = self.transform(x)
49
+ att = self.attn(att)
50
+ return i*att
51
+
52
+ class FDAttention(nn.Module):
53
+ def __init__(
54
+ self
55
+ ) -> None:
56
+ super().__init__()
57
+ def forward(self,x,e):
58
+ """
59
+ Args:
60
+
61
+ x: (B, D, F, T) input tensor from the previous layer (for k and v)
62
+ e: (B, D, F) embedding after reshape (for q)
63
+ """
64
+
65
+ _,D,n_freq,T = x.shape
66
+ q = repeat(e, 'B D F -> B T (D F)', T=T)
67
+ k = rearrange(x, 'B D F T -> B (D F) T')
68
+ v = rearrange(x, 'B D F T -> B T (D F)')
69
+
70
+ q = self.positional_encoding(q)
71
+ qkT = torch.matmul(q,k)/((D*n_freq)**0.5)
72
+ qkT = F.softmax(qkT,dim=-1)
73
+ att = torch.matmul(qkT,v)
74
+
75
+ att = rearrange(att, 'B T (D F) -> B D F T', D=D, F=n_freq)
76
+ return att
77
+
78
+ def positional_encoding(self, x):
79
+ """
80
+ Args:
81
+ x: (B, T, D) input to add positional encoding
82
+ """
83
+ B, T, D = x.shape
84
+ pos = torch.arange(T, device=x.device).unsqueeze(1)
85
+ div_term = torch.exp(torch.arange(0, D, 2, device=x.device) * (-torch.log(torch.tensor(10000.0)) / D))
86
+
87
+ pos_enc = torch.zeros_like(x)
88
+ pos_enc[:, :, 0::2] = torch.sin(pos * div_term)
89
+ pos_enc[:, :, 1::2] = torch.cos(pos * div_term)
90
+ return x + pos_enc
91
+
92
+ class SplitModule(nn.Module):
93
+ def __init__(
94
+ self,
95
+ emb_dim: int = 48,
96
+ condition_dim: int = 256,
97
+ n_fft: int = 128,
98
+ ) -> None:
99
+ super().__init__()
100
+ self.emb_dim = emb_dim
101
+ self.condition_dim = condition_dim
102
+ n_freq = n_fft // 2 + 1
103
+ self.n_freqs = n_freq
104
+
105
+ self.alpha = nn.Parameter(torch.empty(1,emb_dim,self.n_freqs).to(torch.float32))
106
+ self.beta = nn.Parameter(torch.empty(1,emb_dim,self.n_freqs,1).to(torch.float32))
107
+ self.d_att = DimensionDimAttention(emb_dim=emb_dim)
108
+
109
+ # self.f_att = FDAttention()
110
+ def forward(self,input,emb):
111
+ """
112
+ Args:
113
+ input: (B, D, F, T) input tensor
114
+ emb: (B, D, F) embedding after reshape
115
+ """
116
+ e = F.tanh(emb*self.alpha)
117
+
118
+ x = self.d_att(input,e)
119
+ # x = self.f_att(x,e)
120
+ x = x*F.sigmoid(self.beta*emb.unsqueeze(-1))
121
+ return x
network/modules/tf_gridnet_modules/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .dimension_embedding import DimensionEmbedding
2
+ from .tf_gridnet_block import *
3
+ from .deconv import *
network/modules/tf_gridnet_modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (369 Bytes)
network/modules/tf_gridnet_modules/__pycache__/deconv.cpython-311.pyc ADDED
Binary file (1.6 kB)
network/modules/tf_gridnet_modules/__pycache__/dimension_embedding.cpython-311.pyc ADDED
Binary file (1.63 kB)
network/modules/tf_gridnet_modules/__pycache__/tf_gridnet_block.cpython-311.pyc ADDED
Binary file (15.7 kB)
network/modules/tf_gridnet_modules/deconv.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch.nn as nn
2
+ from einops import rearrange
3
+
4
+ class TFGridnetDeconv(nn.Module):
5
+ def __init__(
6
+ self,
7
+ emb_dim = 48,
8
+ n_srcs = 2,
9
+ kernel_size_T = 3,
10
+ kernel_size_F = 3,
11
+ padding_T = 1,
12
+ padding_F = 1,
13
+ ) -> None:
14
+ super().__init__()
15
+ self.n_srcs = n_srcs
16
+ self.deconv = nn.ConvTranspose2d(emb_dim, n_srcs * 2, (kernel_size_F,kernel_size_T), padding=(padding_F,padding_T))
17
+ def forward(self,input):
18
+ output = self.deconv(input)
19
+ output = rearrange(output,"B (N C) F T -> B C N F T", C=self.n_srcs)
20
+ return output
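TFGridnetDeconv projects the feature map back to per-source real/imaginary spectrograms; a shape sketch with two sources:

import torch
from network.modules.tf_gridnet_modules import TFGridnetDeconv

deconv = TFGridnetDeconv(emb_dim=48, n_srcs=2)
features = torch.randn(2, 48, 65, 126)    # (B, D, F, T)
out = deconv(features)                    # (B, n_srcs, 2, F, T): sources on dim 1, real/imag on dim 2
print(out.shape)                          # torch.Size([2, 2, 2, 65, 126])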
network/modules/tf_gridnet_modules/dimension_embedding.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch.nn as nn
2
+ from typing import Tuple
3
+ class DimensionEmbedding(nn.Module):
4
+ def __init__(
5
+ self, audio_channel:int = 1,emb_dim:int = 48,
6
+ kernel_size: Tuple[int,int] = (3,3),
7
+ padding = "same",eps=1.0e-5
8
+ ) -> None:
9
+ super().__init__()
10
+ self.emb = nn.Sequential(
11
+ nn.Conv2d(2*audio_channel, emb_dim, kernel_size,padding=padding),
12
+ nn.GroupNorm(1,emb_dim,eps=eps)
13
+ )
14
+ def forward(self,input):
15
+ return self.emb(input)
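DimensionEmbedding lifts the stacked real/imaginary spectrogram into the emb_dim feature space; a shape sketch with the default 3x3 kernel:

import torch
from network.modules.tf_gridnet_modules import DimensionEmbedding

embed = DimensionEmbedding(audio_channel=1, emb_dim=48, kernel_size=(3, 3))
spec = torch.randn(2, 2, 65, 126)     # (batch, 2*audio_channel [real, imag], freq, frames)
features = embed(spec)                # (batch, 48, 65, 126)
print(features.shape)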
network/modules/tf_gridnet_modules/tf_gridnet_block.py ADDED
@@ -0,0 +1,255 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ import math
6
+
7
+ if hasattr(torch, "bfloat16"):
8
+ HALF_PRECISION_DTYPES = (torch.float16, torch.bfloat16)
9
+ else:
10
+ HALF_PRECISION_DTYPES = (torch.float16,)
11
+
12
+ class IntraAndInterBandModule(nn.Module):
13
+ def __init__(
14
+ self, emb_dim:int = 48,
15
+ kernel_size:int = 4,
16
+ emb_hop_size:int = 1,
17
+ hidden_channels:int = 192,
18
+ eps = 1e-5
19
+ ) -> None:
20
+ super().__init__()
21
+ self.emb_dim = emb_dim
22
+ self.emb_hs = emb_hop_size
23
+ self.kernel_size = kernel_size
24
+ in_channels = emb_dim * kernel_size
25
+
26
+ self.intra_norm = nn.LayerNorm(emb_dim,eps=eps)
27
+
28
+ self.intra_lstm = nn.LSTM(
29
+ in_channels,hidden_channels,1,batch_first=True,bidirectional=True
30
+ )
31
+
32
+ if kernel_size == emb_hop_size:
33
+ self.intra_linear = nn.Linear(hidden_channels*2, in_channels)
34
+ else:
35
+ self.intra_linear = nn.ConvTranspose1d(hidden_channels*2, emb_dim,kernel_size ,emb_hop_size)
36
+
37
+ self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
38
+ self.inter_lstm = nn.LSTM(
39
+ in_channels,hidden_channels,1,batch_first=True,bidirectional=True
40
+ )
41
+
42
+ if kernel_size == emb_hop_size:
43
+ self.inter_linear = nn.Linear(hidden_channels*2, in_channels)
44
+ else:
45
+ self.inter_linear = nn.ConvTranspose1d(hidden_channels*2, emb_dim,kernel_size ,emb_hop_size)
46
+ def forward(self,x):
47
+ """
48
+ Args:
49
+ input (torch.Tensor): [B C Q T]
50
+ output:
51
+ ouput (torch.Tensor): [B C Q T]
52
+ """
53
+ B, C, old_Q, old_T = x.shape
54
+
55
+ padding = self.kernel_size - self.emb_hs
56
+
57
+ T = (
58
+ math.ceil((old_T + 2 * padding - self.kernel_size) / self.emb_hs) * self.emb_hs
59
+ + self.kernel_size
60
+ )
61
+ Q = (
62
+ math.ceil((old_Q + 2 * padding - self.kernel_size) / self.emb_hs) * self.emb_hs
63
+ + self.kernel_size
64
+ )
65
+
66
+ input = rearrange(x, "B C Q T -> B T Q C")
67
+ input = F.pad(input, (0, 0, padding, Q - old_Q - padding, padding, T - old_T - padding))
68
+ intra_rnn = self.intra_norm(input)
69
+ if self.kernel_size == self.emb_hs:
70
+ intra_rnn = intra_rnn.view([B * T, -1, self.kernel_size * C])
71
+ intra_rnn, _ = self.intra_lstm(intra_rnn)
72
+ intra_rnn = self.intra_linear(intra_rnn)
73
+ intra_rnn = intra_rnn.view([B, T, Q, C])
74
+ else:
75
+ intra_rnn = rearrange(intra_rnn,"B T Q C -> (B T) C Q")
76
+ intra_rnn = F.unfold(
77
+ intra_rnn[...,None],(self.kernel_size,1),stride=(self.emb_hs,1)
78
+ )
79
+ intra_rnn = intra_rnn.transpose(1, 2) # [BT, -1, C*I]
80
+ intra_rnn, _ = self.intra_lstm(intra_rnn)
81
+ intra_rnn = intra_rnn.transpose(1, 2) # [BT, H, -1]
82
+ intra_rnn = self.intra_linear(intra_rnn) # [BT, C, Q]
83
+ intra_rnn = intra_rnn.view([B, T, C, Q])
84
+ intra_rnn = intra_rnn.transpose(-2, -1) # [B, T, Q, C]
85
+ intra_rnn = intra_rnn + input
86
+ inter_input = rearrange(intra_rnn, "B T Q C -> B Q T C")
87
+ inter_rnn = self.inter_norm(inter_input)
88
+ if self.kernel_size == self.emb_hs:
89
+ inter_rnn = inter_rnn.view([B * Q, -1, self.kernel_size * C])
90
+ inter_rnn, _ = self.inter_lstm(inter_rnn)
91
+ inter_rnn = self.inter_linear(intra_rnn)
92
+ inter_rnn = inter_rnn.view([B, Q, T, C])
93
+ else:
94
+ inter_rnn = rearrange(inter_rnn,"B Q T C -> (B Q) C T")
95
+ inter_rnn = F.unfold(
96
+ inter_rnn[...,None],(self.kernel_size,1),stride=(self.emb_hs,1)
97
+ )
98
+ inter_rnn = inter_rnn.transpose(1, 2) # [BQ, -1, C*I]
99
+ inter_rnn,_ = self.inter_lstm(inter_rnn)
100
+ inter_rnn = inter_rnn.transpose(1, 2) # [BQ, H, -1]
101
+ inter_rnn = self.inter_linear(inter_rnn) # [BQ, C, T]
102
+ inter_rnn = inter_rnn.view([B, Q, C, T])
103
+ inter_rnn = inter_rnn.transpose(-2, -1) # [B, Q, T, C]
104
+ inter_rnn = inter_rnn + inter_input
105
+
106
+ inter_rnn = rearrange(inter_rnn,"B Q T C -> B C Q T")
107
+ inter_rnn = inter_rnn[..., padding : padding + old_Q, padding : padding + old_T]
108
+
109
+ return inter_rnn
110
+
+ class LayerNormalization(nn.Module):
+     def __init__(self, input_dim, dim=1, total_dim=4, eps=1e-5):
+         super().__init__()
+         self.dim = dim if dim >= 0 else total_dim + dim
+         param_size = [1 if ii != self.dim else input_dim for ii in range(total_dim)]
+         self.gamma = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+         self.beta = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+         nn.init.ones_(self.gamma)
+         nn.init.zeros_(self.beta)
+         self.eps = eps
+
+     @torch.cuda.amp.autocast(enabled=False)
+     def forward(self, x):
+         if x.ndim - 1 < self.dim:
+             raise ValueError(
+                 f"Expect x to have {self.dim + 1} dimensions, but got {x.ndim}"
+             )
+         if x.dtype in HALF_PRECISION_DTYPES:
+             dtype = x.dtype
+             x = x.float()
+         else:
+             dtype = None
+         mu_ = x.mean(dim=self.dim, keepdim=True)
+         std_ = torch.sqrt(x.var(dim=self.dim, unbiased=False, keepdim=True) + self.eps)
+         x_hat = ((x - mu_) / std_) * self.gamma + self.beta
+         return x_hat.to(dtype=dtype) if dtype else x_hat
+
+ class AllHeadPReLULayerNormalization4DC(nn.Module):
+     def __init__(self, input_dimension, eps=1e-5):
+         super().__init__()
+         assert len(input_dimension) == 2, input_dimension
+         H, E = input_dimension
+         param_size = [1, H, E, 1, 1]
+         self.gamma = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+         self.beta = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+         nn.init.ones_(self.gamma)
+         nn.init.zeros_(self.beta)
+         self.act = nn.PReLU(num_parameters=H, init=0.25)
+         self.eps = eps
+         self.H = H
+         self.E = E
+
+     def forward(self, x):
+         assert x.ndim == 4
+         B, _, T, F = x.shape
+         x = x.view([B, self.H, self.E, T, F])
+         x = self.act(x)  # [B,H,E,T,F]
+         stat_dim = (2,)
+         mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,H,1,T,F]
+         std_ = torch.sqrt(
+             x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
+         )  # [B,H,1,T,F]
+         x = ((x - mu_) / std_) * self.gamma + self.beta  # [B,H,E,T,F]
+         return x
+
+ class CrossFrameSelfAttention(nn.Module):
+     def __init__(
+         self,
+         emb_dim = 48,
+         n_freqs = 65,
+         n_head = 4,
+         qk_output_channel = 4,
+         activation = "PReLU",
+         eps = 1e-5
+     ):
+         super().__init__()
+         assert emb_dim % n_head == 0
+         E = qk_output_channel
+         self.conv_Q = nn.Conv2d(emb_dim, n_head * E, 1)
+         self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
+
+         self.conv_K = nn.Conv2d(emb_dim, n_head * E, 1)
+         self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
+
+         self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
+         self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)
+
+         self.concat_proj = nn.Sequential(
+             nn.Conv2d(emb_dim, emb_dim, 1),
+             getattr(nn, activation)(),
+             LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
+         )
+         self.emb_dim = emb_dim
+         self.n_head = n_head
+     def forward(self, x):
+         """
+         Args:
+             x (torch.Tensor): [B C Q T]
+         Returns:
+             output (torch.Tensor): [B C Q T]
+         """
+
+         input = rearrange(x, "B C Q T -> B C T Q")
+         Q = self.norm_Q(self.conv_Q(input))  # [B, n_head, E, T, Q]
+         K = self.norm_K(self.conv_K(input))  # [B, n_head, E, T, Q]
+         V = self.norm_V(self.conv_V(input))  # [B, n_head, emb_dim // n_head, T, Q]
+
+         Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
+         K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
+         batch, n_head, channel, frame, freq = V.shape
+         V = rearrange(V, "B H C T Q -> (B H) T (C Q)")
+         emb_dim = Q.shape[-1]
+         qkT = torch.matmul(Q, K) / (emb_dim**0.5)
+         qkT = F.softmax(qkT, dim=2)
+         att = torch.matmul(qkT, V)
+         att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q", C=channel, Q=freq, H=n_head, B=batch, T=frame)
+         att = self.concat_proj(att)
+         out = att + input
+         out = rearrange(out, "B C T Q -> B C Q T")
+         return out
+
+ class TFGridnetBlock(nn.Module):
+     def __init__(
+         self,
+         emb_dim = 48,
+         kernel_size: int = 4,
+         emb_hop_size: int = 1,
+         n_freqs = 65,
+         hidden_channels: int = 192,
+         n_head = 4,
+         qk_output_channel = 4,
+         activation = "PReLU",
+         eps = 1e-5
+     ):
+         super().__init__()
+         self.tf_grid_block = nn.Sequential(
+             IntraAndInterBandModule(
+                 emb_dim=emb_dim,
+                 kernel_size=kernel_size,
+                 emb_hop_size=emb_hop_size,
+                 hidden_channels=hidden_channels,
+                 eps=eps
+             ),
+             CrossFrameSelfAttention(
+                 emb_dim=emb_dim,
+                 n_freqs=n_freqs,
+                 n_head=n_head,
+                 qk_output_channel=qk_output_channel,
+                 activation=activation,
+                 eps=eps
+             )
+         )
+     def forward(self, input):
+         return self.tf_grid_block(input)
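
For orientation, a minimal sketch that chains the three modules in this package (embedding, one grid block, deconvolution). The parameter values are the defaults shown above; the tensor sizes are assumptions for illustration, and a full model would typically stack several TFGridnetBlock layers:

    import torch
    from network.modules.tf_gridnet_modules.dimension_embedding import DimensionEmbedding
    from network.modules.tf_gridnet_modules.tf_gridnet_block import TFGridnetBlock
    from network.modules.tf_gridnet_modules.deconv import TFGridnetDeconv

    spec = torch.randn(2, 2, 65, 50)                 # [B, 2, F, T] two-channel spectrogram input
    embed = DimensionEmbedding(audio_channel=1, emb_dim=48)
    block = TFGridnetBlock(emb_dim=48, kernel_size=4, emb_hop_size=1, n_freqs=65)
    deconv = TFGridnetDeconv(emb_dim=48, n_srcs=2)

    x = embed(spec)      # [B, 48, F, T]
    x = block(x)         # [B, 48, F, T], intra-band / inter-band LSTMs + cross-frame attention
    out = deconv(x)      # [B, n_srcs, 2, F, T] -> torch.Size([2, 2, 2, 65, 50])
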
network/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .error_message import ErrorMessageUtil
+ from .enum_declare import *
network/utils/enum_declare.py ADDED
@@ -0,0 +1,6 @@
+ from dataclasses import dataclass
+
+ @dataclass
+ class STFT_transform_type_enum:
+     exponent = "exponent"
+     log = "log"
network/utils/error_message.py ADDED
@@ -0,0 +1,7 @@
+ from dataclasses import dataclass
+
+ @dataclass
+ class ErrorMessageUtil:
+     only_support_batch_input = "Only batched input is supported. For a single example, call .unsqueeze(0) first."
+     complex_format_convert = "The input must satisfy input.size(1) == 2."
+     two_input_in_the_same_shape = "The two inputs must have the same shape."
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==2.0.1
+ torchaudio==2.0.1
+ numpy==1.23.5
+ einops==0.7.0
+ pandas==1.5.3
+ scikit-learn==1.2.2
+ resemblyzer
+ gradio