hieugiaosu committed · Commit 7596274 · 1 Parent(s): e9096c9
Add application file
- app.py +102 -0
- checkpoints/concat_emb.pth +3 -0
- network/__init__.py +1 -0
- network/layers/ISTFT_layer.py +69 -0
- network/layers/STFT_Layer.py +67 -0
- network/layers/__init__.py +3 -0
- network/layers/__pycache__/ISTFT_layer.cpython-311.pyc +0 -0
- network/layers/__pycache__/STFTLayer.cpython-311.pyc +0 -0
- network/layers/__pycache__/STFT_Layer.cpython-311.pyc +0 -0
- network/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- network/layers/__pycache__/film_layer.cpython-311.pyc +0 -0
- network/layers/film_layer.py +21 -0
- network/models/TF_gridnet_with_condition.py +638 -0
- network/models/__init__.py +2 -0
- network/models/__pycache__/TF_gridnet.cpython-311.pyc +0 -0
- network/models/__pycache__/TF_gridnet_with_condition.cpython-311.pyc +0 -0
- network/models/__pycache__/TF_gridnet_with_condition_new.cpython-311.pyc +0 -0
- network/models/__pycache__/__init__.cpython-311.pyc +0 -0
- network/models/__pycache__/embedding_model.cpython-311.pyc +0 -0
- network/models/__pycache__/new_style.cpython-311.pyc +0 -0
- network/models/__pycache__/tfGrideNetLOTH.cpython-311.pyc +0 -0
- network/models/embedding_model.py +14 -0
- network/modules/__init__.py +0 -0
- network/modules/attention.py +197 -0
- network/modules/gate_module.py +16 -0
- network/modules/input_tranformation.py +92 -0
- network/modules/output_transformation.py +47 -0
- network/modules/sequence_embed.py +138 -0
- network/modules/split_modules.py +121 -0
- network/modules/tf_gridnet_modules/__init__.py +3 -0
- network/modules/tf_gridnet_modules/__pycache__/__init__.cpython-311.pyc +0 -0
- network/modules/tf_gridnet_modules/__pycache__/deconv.cpython-311.pyc +0 -0
- network/modules/tf_gridnet_modules/__pycache__/dimension_embedding.cpython-311.pyc +0 -0
- network/modules/tf_gridnet_modules/__pycache__/tf_gridnet_block.cpython-311.pyc +0 -0
- network/modules/tf_gridnet_modules/deconv.py +20 -0
- network/modules/tf_gridnet_modules/dimension_embedding.py +15 -0
- network/modules/tf_gridnet_modules/tf_gridnet_block.py +255 -0
- network/utils/__init__.py +2 -0
- network/utils/enum_declare.py +6 -0
- network/utils/error_message.py +7 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,102 @@
import gradio as gr
import torch
import torchaudio
from torch.cuda.amp import autocast

from network.models import FilterBandTFGridnet, ResemblyzerVoiceEncoder

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
model = FilterBandTFGridnet(n_layers=5, conditional_dim=256*2)
emb = ResemblyzerVoiceEncoder(device=device)
mixed_voice_tool = None

def load_voice(voice_path):
    voice, rate = torchaudio.load(voice_path)
    if rate != 16000:
        # Resample to the 16 kHz rate the rest of the app assumes
        # (the original resampled to 8000 here, contradicting rate = 16000).
        voice = torchaudio.functional.resample(voice, rate, 16000)
        rate = 16000
    voice = voice.float()
    return voice, rate

def mix(voice1_path, voice2_path, snr=0):
    global mixed_voice_tool
    voice1, _ = load_voice(voice1_path)
    voice2, _ = load_voice(voice2_path)
    mix = torchaudio.functional.add_noise(voice1, voice2, torch.tensor([float(snr)])).float()
    mixed_voice_tool = mix
    return gr.Audio((16000, mix[0].numpy()), type='numpy')

# def seprate_from_file(mixed_voice, ref_voice):

def seprate(mixed_voice_path, clean_voice_path, drop_down):
    if drop_down == 'From mixing tool':
        mixed_voice = mixed_voice_tool
    else:
        mixed_voice, rate = load_voice(mixed_voice_path)
    clean_voice, rate = load_voice(clean_voice_path)
    if clean_voice.shape[-1] < 16000*4:
        # Tile short reference clips until they cover 4 seconds.
        n = 16000*4 // clean_voice.shape[-1] + 1
        clean_voice = torch.cat([clean_voice]*n, dim=-1)
    clean_voice = clean_voice[:, :16000*4]
    if not model:
        return None
    model.to(device)
    model.eval()
    e = emb(clean_voice)
    e_mix = emb(mixed_voice)
    e = torch.cat([e, e_mix], dim=1)
    mixed_voice = torchaudio.functional.resample(mixed_voice, rate, 8000)
    with autocast():
        with torch.no_grad():
            yHat = model(
                mixed_voice,
                e,
            )
    yHat = torchaudio.functional.resample(yHat, 8000, 16000).numpy().astype('float32')
    audio = gr.Audio((16000, yHat[0]), type='numpy')
    return audio

def load_checkpoint(filepath):
    checkpoint = torch.load(
        filepath,
        weights_only=True,
        map_location=device,
    )
    model.load_state_dict(checkpoint)

with gr.Blocks() as demo:
    load_checkpoint('checkpoints/concat_emb.pth')
    with gr.Row():
        snr = gr.Slider(label='SNR', minimum=-10, maximum=10, step=1, value=0)
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            voice1 = gr.Audio(label='speaker 1', type='filepath')
        with gr.Column(scale=1, min_width=200):
            voice2 = gr.Audio(label='speaker 2', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                mixed_voice = gr.Audio(label='Mixed voice')
            with gr.Row():
                btn = gr.Button("Mix voices", size='sm')
    btn.click(mix, inputs=[voice1, voice2, snr], outputs=mixed_voice)
    with gr.Row():
        choose_mix_source = gr.Label('Extract target speaker voice from mixed voice')
    with gr.Row():
        drop_down = gr.Dropdown(['From mixing tool', 'Upload'], label='Choose mixed voice source')
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                mixed_voice_path = gr.Audio(label='Mixed voice', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                ref_voice_path = gr.Audio(label='reference voice', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                sep_voice = gr.Audio(label="Separate Voice")
    with gr.Row():
        btn = gr.Button("Separate voices", size='sm')
    btn.click(seprate, inputs=[mixed_voice_path, ref_voice_path, drop_down], outputs=sep_voice)
demo.launch()
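Outside the Gradio UI, the same pipeline can be driven directly. A minimal sketch, assuming the repo's network package is importable, the LFS checkpoint has been pulled, and resemblyzer is installed; mixed.wav and reference.wav are hypothetical 16 kHz input files:

import torch
import torchaudio
from network.models import FilterBandTFGridnet, ResemblyzerVoiceEncoder

model = FilterBandTFGridnet(n_layers=5, conditional_dim=256*2)
model.load_state_dict(torch.load('checkpoints/concat_emb.pth',
                                 weights_only=True, map_location='cpu'))
model.eval()
enc = ResemblyzerVoiceEncoder(device='cpu')

mixed, _ = torchaudio.load('mixed.wav')          # hypothetical file names
ref, _ = torchaudio.load('reference.wav')
clue = torch.cat([enc(ref), enc(mixed)], dim=1)  # (1, 512), as in seprate()
with torch.no_grad():
    est = model(torchaudio.functional.resample(mixed, 16000, 8000), clue)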
checkpoints/concat_emb.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6c0d3e074f574a701ef9231d787611c47e86556b7ef7d63e1e0f6e4a4a6caa73
size 39815976
network/__init__.py
ADDED
@@ -0,0 +1 @@
__package__ = "network"
network/layers/ISTFT_layer.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
from typing import Optional
from ..utils import ErrorMessageUtil
from einops import rearrange

class InverseSTFTLayer(nn.Module):
    def __init__(
        self,
        n_fft: int = 128,
        win_length: Optional[int] = None,
        hop_length: int = 64,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length if win_length else n_fft
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        self.window = getattr(torch, f"{window}_window")

    def forward(self, input, audio_length: int):
        """Inverse STFT forward function.
        Args:
            input: (Batch, Freq, Frames) or (Batch, Channels, Freq, Frames)
        Returns:
            output: (Batch, Nsamples) or (Batch, Channel, Nsamples)

        Notice:
            input is a complex tensor
        """
        assert input.dim() == 4 or input.dim() == 3, ErrorMessageUtil.only_support_batch_input
        batch_size = input.size(0)
        multi_channel = (input.dim() == 4)
        if multi_channel:
            input = rearrange(input, "b c f t -> (b c) f t")
        window = self.window(
            self.win_length,
            dtype=input.real.dtype,
            device=input.device
        )
        istft_kwargs = dict(
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.hop_length,
            center=self.center,
            window=window,
            length=audio_length,
            return_complex=False
        )

        wave = torch.istft(input, **istft_kwargs)
        if multi_channel:
            wave = rearrange(wave, "(b c) l -> b c l", b=batch_size)
        return wave

class ComplexTensorLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):  # was `def forward(seal, input)`: typo for self
        assert input.shape[1] == 2, ErrorMessageUtil.complex_format_convert
        real = input[:, 0]
        imag = input[:, 1]

        return torch.complex(real, imag)
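For orientation, a standalone sketch (plain torch, no repo imports) of the (real, imag) channel convention that ComplexTensorLayer converts before torch.istft; shapes and parameters here are illustrative:

import torch

spec2ch = torch.randn(4, 2, 65, 100)                # (batch, re/im, freq, frames)
spec = torch.complex(spec2ch[:, 0], spec2ch[:, 1])  # pack the pair into one complex tensor
wave = torch.istft(spec, n_fft=128, hop_length=64,
                   window=torch.hann_window(128))   # hann with hop = n_fft/2 satisfies NOLA
print(wave.shape)                                   # (4, 6336) = (frames - 1) * hop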
network/layers/STFT_Layer.py
ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
from typing import Optional
from ..utils import ErrorMessageUtil
from einops import rearrange

class STFTLayer(nn.Module):
    def __init__(
        self,
        n_fft: int = 128,
        win_length: Optional[int] = None,
        hop_length: int = 64,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        pad_mode: str = "reflect"
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length if win_length else n_fft
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        self.pad_mode = pad_mode
        self.window = getattr(torch, f"{window}_window")

    def forward(self, input: torch.Tensor):
        """STFT forward function.
        Args:
            input: (Batch, Nsamples) or (Batch, Channel, Nsamples)
        Returns:
            output: (Batch, Freq, Frames) or (Batch, Channels, Freq, Frames)
        Notice:
            output is a complex tensor
        """
        assert input.dim() == 2 or input.dim() == 3, ErrorMessageUtil.only_support_batch_input
        batch_size = input.size(0)
        multi_channel = (input.dim() == 3)
        if multi_channel:
            input = rearrange(input, "b c l -> (b c) l")
        window = self.window(
            self.win_length,
            dtype=input.dtype,
            device=input.device
        )

        stft_kwargs = dict(
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.hop_length,
            center=self.center,
            window=window,
            pad_mode=self.pad_mode,
            return_complex=True
        )

        # Zero-pad the window to n_fft so the effective win_length is always n_fft.
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left
        stft_kwargs["window"] = torch.cat(
            [torch.zeros(n_pad_left, device=input.device), window, torch.zeros(n_pad_right, device=input.device)], 0
        )

        output = torch.stft(input, **stft_kwargs)
        if multi_channel:
            output = rearrange(output, "(b c) f t -> b c f t", b=batch_size)
        return output
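A quick round-trip sanity check for the defaults used above (n_fft=128, hop=64, hann window, center=True); a standalone sketch with illustrative shapes, not part of the repo:

import torch

x = torch.randn(2, 6400)
win = torch.hann_window(128)
spec = torch.stft(x, n_fft=128, hop_length=64, window=win,
                  center=True, pad_mode="reflect", return_complex=True)
y = torch.istft(spec, n_fft=128, hop_length=64, window=win, length=6400)
print(torch.allclose(x, y, atol=1e-5))  # True up to float error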
network/layers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .STFT_Layer import STFTLayer
from .ISTFT_layer import *
from .film_layer import FiLMLayer
network/layers/__pycache__/ISTFT_layer.cpython-311.pyc
ADDED
Binary file (4.13 kB).
network/layers/__pycache__/STFTLayer.cpython-311.pyc
ADDED
Binary file (3.66 kB).
network/layers/__pycache__/STFT_Layer.cpython-311.pyc
ADDED
Binary file (3.63 kB).
network/layers/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (358 Bytes).
network/layers/__pycache__/film_layer.cpython-311.pyc
ADDED
Binary file (2.04 kB).
network/layers/film_layer.py
ADDED
@@ -0,0 +1,21 @@
import torch.nn as nn
from einops import rearrange

class FiLMLayer(nn.Module):
    def __init__(self, channels, conditional_dim=256, apply_dim=1):
        super().__init__()
        self.alpha = nn.Linear(conditional_dim, channels)
        self.beta = nn.Linear(conditional_dim, channels)
        self.apply_dim = apply_dim

    def forward(self, x, condition):
        alpha = self.alpha(condition)
        beta = self.beta(condition)
        input = x
        if self.apply_dim != 1:
            input = input.transpose(1, -1)
        alpha = rearrange(alpha, "b d -> b d" + " 1"*(x.dim()-alpha.dim()))
        beta = rearrange(beta, "b d -> b d" + " 1"*(x.dim()-beta.dim()))
        out = alpha*input + beta
        if self.apply_dim != 1:
            out = out.transpose(1, -1)
        return out
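FiLM applies a per-channel affine transform predicted from the condition vector. A shape sketch, assuming the repo's network package is importable; all sizes are illustrative:

import torch
from network.layers import FiLMLayer

film = FiLMLayer(channels=48, conditional_dim=256)
x = torch.randn(2, 48, 65, 100)  # (B, D, F, T) feature map
cond = torch.randn(2, 256)       # conditioning vector, e.g. a speaker embedding
out = film(x, cond)              # alpha(cond)*x + beta(cond), broadcast over F and T
print(out.shape)                 # torch.Size([2, 48, 65, 100])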
network/models/TF_gridnet_with_condition.py
ADDED
@@ -0,0 +1,638 @@
import torch  # added: torch.cat is used below, but only torch.nn was imported
import torch.nn as nn
from ..modules.tf_gridnet_modules import *
from ..modules.input_tranformation import STFTInput, RMSNormalizeInput
from ..modules.output_transformation import WaveGeneratorByISTFT, RMSDenormalizeOutput
from ..modules.convolution_module import SplitFeatureDeconv
from ..modules.attention import *
from .TF_gridnet import TF_Gridnet
from ..layers import FiLMLayer
from ..modules.gate_module import BandFilterGate
from einops import rearrange, repeat
import math

class TFGridFormer(nn.Module):
    def __init__(
        self,
        n_srcs=2,
        n_fft=128,
        hop_length=64,
        window="hann",
        n_audio_channel=1,
        n_layers=6,
        input_kernel_size_T=3,
        input_kernel_size_F=3,
        output_kernel_size_T=3,
        output_kernel_size_F=3,
        lstm_hidden_units=192,
        attn_n_head=4,
        qk_output_channel=4,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="PReLU",
        eps=1.0e-5,
    ):
        super().__init__()
        self.ref_input_normalize = RMSNormalizeInput((1, 2), keepdim=True)
        self.mix_input_normalize = RMSNormalizeInput((1, 2), keepdim=True)
        self.stft = STFTInput(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )

        self.mix_dimension_embedding = DimensionEmbedding(
            audio_channel=n_audio_channel,
            emb_dim=emb_dim,
            kernel_size=(input_kernel_size_F, input_kernel_size_T),
            eps=eps
        )

        self.ref_dimension_embedding = DimensionEmbedding(
            audio_channel=n_audio_channel,
            emb_dim=emb_dim,
            kernel_size=(input_kernel_size_F, input_kernel_size_T),
            eps=eps
        )

        self.istft = WaveGeneratorByISTFT(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window
        )

        self.output_denormalize = RMSDenormalizeOutput()

        self.ref_encoder = TFGridnetBlock(
            emb_dim=emb_dim,
            kernel_size=emb_ks,
            emb_hop_size=emb_hs,
            hidden_channels=lstm_hidden_units,
            n_head=attn_n_head,
            qk_output_channel=qk_output_channel,
            activation=activation,
            eps=eps
        )

        mix_encode_layers = math.ceil(n_layers*2/3)
        mix_decode_layers = n_layers - mix_encode_layers

        self.mix_encoder = nn.Sequential(
            *[
                TFGridnetBlock(
                    emb_dim=emb_dim,
                    kernel_size=emb_ks,
                    emb_hop_size=emb_hs,
                    hidden_channels=lstm_hidden_units,
                    n_head=attn_n_head,
                    qk_output_channel=qk_output_channel,
                    activation=activation,
                    eps=eps
                ) for _ in range(mix_encode_layers)
            ]
        )

        self.split_layer = SplitFeatureDeconv(
            emb_dim=emb_dim,
            n_srcs=n_srcs,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F//2,
            padding_T=output_kernel_size_T//2
        )

        self.intra_frame_cross_att = IntraFrameCrossAttention(
            emb_dim=emb_dim,
            n_head=attn_n_head,
            qk_output_channel=attn_n_head*3,
            activation=activation,
            eps=eps
        )
        self.cross_frame_cross_att = CrossFrameCrossAttention(
            emb_dim=emb_dim,
            n_head=attn_n_head,
            qk_output_channel=qk_output_channel,
            activation=activation,
            eps=eps
        )

        self.mix_decoder = nn.Sequential(
            *[
                TFGridnetBlock(
                    emb_dim=emb_dim,
                    kernel_size=emb_ks,
                    emb_hop_size=emb_hs,
                    hidden_channels=lstm_hidden_units,
                    n_head=attn_n_head,
                    qk_output_channel=qk_output_channel,
                    activation=activation,
                    eps=eps
                ) for _ in range(mix_decode_layers)
            ]
        )

        self.deconv = TFGridnetDeconv(
            emb_dim=emb_dim,
            n_srcs=1,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F//2,
            padding_T=output_kernel_size_T//2
        )

        self.middle_deconv = TFGridnetDeconv(
            emb_dim=emb_dim,
            n_srcs=n_srcs,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F//2,
            padding_T=output_kernel_size_T//2
        )

    def forward(self, mix, ref, middle=False):
        audio_length = mix.shape[-1]

        x = mix
        c = ref

        if x.dim() == 2:
            x = x.unsqueeze(1)
        if c.dim() == 2:
            c = c.unsqueeze(1)
        x, std = self.mix_input_normalize(x)

        x = self.stft(x)

        c, _ = self.ref_input_normalize(c)

        c = self.stft(c)

        c = self.ref_dimension_embedding(c)

        c = self.ref_encoder(c)

        x = self.mix_dimension_embedding(x)

        x = self.mix_encoder(x)

        m = None
        if middle:
            m = self.middle_deconv(x)
            m = rearrange(m, "B C N F T -> B N C F T")
            m = self.istft(m, audio_length)
            m = self.output_denormalize(m, std)

        x = self.split_layer(x)

        x = self.intra_frame_cross_att(c, x)

        x = self.cross_frame_cross_att(c, x)

        x = self.mix_decoder(x)

        x = self.deconv(x)

        x = rearrange(x, "B C N F T -> B N C F T")  # because in istft, dim 1 holds the real and imaginary parts

        x = self.istft(x, audio_length)

        x = self.output_denormalize(x, std)
        if middle:
            return x[:, 0], m
        return x[:, 0]

class DoubleChannelTFGridNet(TF_Gridnet):
    def __init__(self,
                 # n_srcs=2,
                 n_fft=128,
                 hop_length=64,
                 window="hann",
                 n_audio_channel=1,
                 n_layers=6,
                 input_kernel_size_T=3,
                 input_kernel_size_F=3,
                 output_kernel_size_T=3,
                 output_kernel_size_F=3,
                 lstm_hidden_units=192,
                 attn_n_head=4,
                 qk_output_channel=4,
                 emb_dim=48,
                 emb_ks=4,
                 emb_hs=1,
                 activation="PReLU",
                 eps=0.00001):
        super().__init__(1, n_fft, hop_length, window, n_audio_channel*2, n_layers, input_kernel_size_T, input_kernel_size_F, output_kernel_size_T, output_kernel_size_F, lstm_hidden_units, attn_n_head, qk_output_channel, emb_dim, emb_ks, emb_hs, activation, eps)

    def forward(self, input, condition):
        x = input
        c = condition

        if x.dim() == 2:
            x = x.unsqueeze(1)
        if c.dim() == 2:
            c = c.unsqueeze(1)
        tc = c.shape[-1]
        tx = x.shape[-1]
        if tc >= tx:
            c = c[:, :, -tx:]
        else:
            # Tile the shorter reference along time until it covers the mixture.
            n = math.ceil(tx/tc)
            c = repeat(c, "b c t -> b c (t n)", n=n)
            c = c[:, :, -tx:]

        mix_with_clue = torch.cat([x, c], dim=1)
        o = super().forward(mix_with_clue)
        return o[:, 0]


class TargetSpeakerTF(nn.Module):
    def __init__(
        self,
        n_srcs=2,
        n_fft=128,
        hop_length=64,
        window="hann",
        n_audio_channel=1,
        n_layers=6,
        input_kernel_size_T=3,
        input_kernel_size_F=3,
        output_kernel_size_T=3,
        output_kernel_size_F=3,
        lstm_hidden_units=192,
        attn_n_head=4,
        qk_output_channel=4,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="PReLU",
        eps=1.0e-5,
        conditional_dim=256
    ):
        super().__init__()

        self.input_normalize = RMSNormalizeInput((1, 2), keepdim=True)
        self.stft = STFTInput(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )

        self.istft = WaveGeneratorByISTFT(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window
        )

        self.output_denormalize = RMSDenormalizeOutput()

        self.dimension_embedding = DimensionEmbedding(
            audio_channel=n_audio_channel,
            emb_dim=emb_dim,
            kernel_size=(input_kernel_size_F, input_kernel_size_T),
            eps=eps
        )

        self.tf_gridnet_block = nn.ModuleList(
            [
                TFGridnetBlock(
                    emb_dim=emb_dim,
                    kernel_size=emb_ks,
                    emb_hop_size=emb_hs,
                    hidden_channels=lstm_hidden_units,
                    n_head=attn_n_head,
                    qk_output_channel=qk_output_channel,
                    activation=activation,
                    eps=eps
                ) for _ in range(n_layers)
            ]
        )

        self.film_layer = nn.ModuleList(
            [
                FiLMLayer(emb_dim, conditional_dim=conditional_dim, apply_dim=1)
                for _ in range(n_layers)
            ]
        )

        self.deconv = TFGridnetDeconv(
            emb_dim=emb_dim,
            n_srcs=n_srcs,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F//2,
            padding_T=output_kernel_size_T//2
        )

        self.n_layers = n_layers

    def forward(self, input, clue):
        audio_length = input.shape[-1]

        x = input

        if x.dim() == 2:
            x = x.unsqueeze(1)

        x, std = self.input_normalize(x)

        x = self.stft(x)

        x = self.dimension_embedding(x)
        for i in range(self.n_layers):
            x = self.tf_gridnet_block[i](x)
            x = self.film_layer[i](x, clue)

        x = self.deconv(x)

        x = rearrange(x, "B C N F T -> B N C F T")  # because in istft, dim 1 holds the real and imaginary parts

        x = self.istft(x, audio_length)

        x = self.output_denormalize(x, std)

        return x

class DoubleChannelTargetSpeakerTF(TargetSpeakerTF):
    def __init__(self,
                 # n_srcs=2,
                 n_fft=128,
                 hop_length=64,
                 window="hann",
                 n_audio_channel=1,
                 n_layers=6,
                 input_kernel_size_T=3,
                 input_kernel_size_F=3,
                 output_kernel_size_T=3,
                 output_kernel_size_F=3,
                 lstm_hidden_units=192,
                 attn_n_head=4,
                 qk_output_channel=4,
                 emb_dim=48,
                 emb_ks=4,
                 emb_hs=1,
                 activation="PReLU",
                 eps=0.00001,
                 conditional_dim=256
                 ):
        super().__init__(1, n_fft, hop_length, window, n_audio_channel*2, n_layers, input_kernel_size_T, input_kernel_size_F, output_kernel_size_T, output_kernel_size_F, lstm_hidden_units, attn_n_head, qk_output_channel, emb_dim, emb_ks, emb_hs, activation, eps, conditional_dim)

    def forward(self, input, reference, embedding):
        x = input
        c = reference

        if x.dim() == 2:
            x = x.unsqueeze(1)
        if c.dim() == 2:
            c = c.unsqueeze(1)
        tc = c.shape[-1]
        tx = x.shape[-1]
        if tc >= tx:
            c = c[:, :, -tx:]
        else:
            n = math.ceil(tx/tc)
            c = repeat(c, "b c t -> b c (t n)", n=n)
            c = c[:, :, -tx:]

        mix_with_clue = torch.cat([x, c], dim=1)
        o = super().forward(mix_with_clue, embedding)
        return o[:, 0]

class FilterBandTFGridnet(nn.Module):
    def __init__(
        self,
        # n_srcs=2,
        n_fft=128,
        hop_length=64,
        window="hann",
        n_audio_channel=1,
        n_layers=6,
        input_kernel_size_T=3,
        input_kernel_size_F=3,
        output_kernel_size_T=3,
        output_kernel_size_F=3,
        lstm_hidden_units=192,
        attn_n_head=4,
        qk_output_channel=4,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="PReLU",
        eps=1.0e-5,
        conditional_dim=256
    ):
        super().__init__()
        n_freqs = n_fft//2 + 1
        self.input_normalize = RMSNormalizeInput((1, 2), keepdim=True)
        self.stft = STFTInput(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )

        self.istft = WaveGeneratorByISTFT(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window
        )

        self.output_denormalize = RMSDenormalizeOutput()

        self.dimension_embedding = DimensionEmbedding(
            audio_channel=n_audio_channel,
            emb_dim=emb_dim,
            kernel_size=(input_kernel_size_F, input_kernel_size_T),
            eps=eps
        )

        self.tf_gridnet_block = nn.ModuleList(
            [
                TFGridnetBlock(
                    emb_dim=emb_dim,
                    kernel_size=emb_ks,
                    emb_hop_size=emb_hs,
                    hidden_channels=lstm_hidden_units,
                    n_head=attn_n_head,
                    qk_output_channel=qk_output_channel,
                    activation=activation,
                    eps=eps
                ) for _ in range(n_layers)
            ]
        )

        self.filter_gen = nn.Linear(conditional_dim, emb_dim*n_freqs)
        self.bias_gen = nn.Linear(conditional_dim, emb_dim*n_freqs)

        self.gates = nn.ModuleList(
            [
                BandFilterGate(emb_dim, n_freqs)
                for _ in range(n_layers)
            ]
        )

        self.deconv = TFGridnetDeconv(
            emb_dim=emb_dim,
            n_srcs=1,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F//2,
            padding_T=output_kernel_size_T//2
        )

        self.n_layers = n_layers

    def forward(self, input, clue):
        audio_length = input.shape[-1]

        x = input

        if x.dim() == 2:
            x = x.unsqueeze(1)

        x, std = self.input_normalize(x)

        x = self.stft(x)

        x = self.dimension_embedding(x)

        n_freqs = x.shape[-2]
        f = self.filter_gen(clue)
        b = self.bias_gen(clue)
        f = rearrange(f, "b (d q) -> b d q 1", q=n_freqs)
        b = rearrange(b, "b (d q) -> b d q 1", q=n_freqs)

        for i in range(self.n_layers):
            x = self.tf_gridnet_block[i](x)
            x = self.gates[i](x, f, b)

        x = self.deconv(x)

        x = rearrange(x, "B C N F T -> B N C F T")  # because in istft, dim 1 holds the real and imaginary parts

        x = self.istft(x, audio_length)

        x = self.output_denormalize(x, std)

        return x[:, 0]

class FilterBandTFGridnetWithAttentionGate(nn.Module):
    def __init__(
        self,
        # n_srcs=2,
        n_fft=128,
        hop_length=64,
        window="hann",
        n_audio_channel=1,
        n_layers=6,
        input_kernel_size_T=3,
        input_kernel_size_F=3,
        output_kernel_size_T=3,
        output_kernel_size_F=3,
        lstm_hidden_units=192,
        attn_n_head=4,
        qk_output_channel=4,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="PReLU",
        eps=1.0e-5,
        conditional_dim=256
    ):
        super().__init__()
        n_freqs = n_fft//2 + 1
        self.input_normalize = RMSNormalizeInput((1, 2), keepdim=True)
        self.stft = STFTInput(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window,
        )

        self.istft = WaveGeneratorByISTFT(
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=hop_length,
            window=window
        )

        self.output_denormalize = RMSDenormalizeOutput()

        self.dimension_embedding = DimensionEmbedding(
            audio_channel=n_audio_channel,
            emb_dim=emb_dim,
            kernel_size=(input_kernel_size_F, input_kernel_size_T),
            eps=eps
        )

        self.tf_gridnet_block = nn.ModuleList(
            [
                TFGridnetBlock(
                    emb_dim=emb_dim,
                    kernel_size=emb_ks,
                    emb_hop_size=emb_hs,
                    hidden_channels=lstm_hidden_units,
                    n_head=attn_n_head,
                    qk_output_channel=qk_output_channel,
                    activation=activation,
                    eps=eps
                ) for _ in range(n_layers)
            ]
        )

        self.query_gen = nn.Linear(conditional_dim, emb_dim*n_freqs)

        self.attentions = nn.ModuleList(
            [
                CrossAttentionFilterV2(emb_dim)
                for _ in range(n_layers)
            ]
        )

        self.deconv = TFGridnetDeconv(
            emb_dim=emb_dim,
            n_srcs=1,
            kernel_size_T=output_kernel_size_T,
            kernel_size_F=output_kernel_size_F,
            padding_F=output_kernel_size_F//2,
            padding_T=output_kernel_size_T//2
        )

        self.n_layers = n_layers

    def forward(self, input, clue):
        audio_length = input.shape[-1]

        x = input

        if x.dim() == 2:
            x = x.unsqueeze(1)

        x, std = self.input_normalize(x)

        x = self.stft(x)

        x = self.dimension_embedding(x)

        n_freqs = x.shape[-2]

        q = self.query_gen(clue)
        q = rearrange(q, "b (d f) -> b f d", f=n_freqs)

        for i in range(self.n_layers):
            x = self.tf_gridnet_block[i](x)
            x = self.attentions[i](q, x)

        x = self.deconv(x)

        x = rearrange(x, "B C N F T -> B N C F T")  # because in istft, dim 1 holds the real and imaginary parts

        x = self.istft(x, audio_length)

        x = self.output_denormalize(x, std)

        return x[:, 0]
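FilterBandTFGridnet is the class app.py instantiates. A minimal forward-pass sketch, assuming the package imports resolve and using the app's hyperparameters (n_layers=5, conditional_dim=256*2); input sizes are illustrative:

import torch
from network.models import FilterBandTFGridnet

model = FilterBandTFGridnet(n_layers=5, conditional_dim=512)
model.eval()
mix = torch.randn(1, 32000)  # 4 s of 8 kHz audio, matching app.py's resampling
clue = torch.randn(1, 512)   # concatenated reference + mixture embeddings
with torch.no_grad():
    est = model(mix, clue)   # (1, 32000) estimated target speech
print(est.shape)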
network/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .TF_gridnet_with_condition import *
from .embedding_model import ResemblyzerVoiceEncoder
network/models/__pycache__/TF_gridnet.cpython-311.pyc
ADDED
Binary file (4 kB).
network/models/__pycache__/TF_gridnet_with_condition.cpython-311.pyc
ADDED
Binary file (22.4 kB).
network/models/__pycache__/TF_gridnet_with_condition_new.cpython-311.pyc
ADDED
Binary file (5.3 kB).
network/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (592 Bytes).
network/models/__pycache__/embedding_model.cpython-311.pyc
ADDED
Binary file (2.14 kB).
network/models/__pycache__/new_style.cpython-311.pyc
ADDED
Binary file (8.66 kB).
network/models/__pycache__/tfGrideNetLOTH.cpython-311.pyc
ADDED
Binary file (5.41 kB).
network/models/embedding_model.py
ADDED
@@ -0,0 +1,14 @@
import torch
from resemblyzer import VoiceEncoder

class ResemblyzerVoiceEncoder:
    def __init__(self, device) -> None:
        self.model = VoiceEncoder(device)

    def __call__(self, audio: torch.Tensor):
        if audio.ndimension() == 1:
            return torch.tensor(self.model.embed_utterance(audio.numpy())).float().cpu()
        else:
            e = torch.stack([torch.tensor(self.model.embed_utterance(audio[i, :].numpy())).float().cpu()
                             for i in range(audio.shape[0])])
            return e
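A usage sketch for this wrapper, assuming resemblyzer is installed; resemblyzer's d-vectors are 256-dimensional and its encoder expects 16 kHz audio:

import torch
from network.models import ResemblyzerVoiceEncoder

enc = ResemblyzerVoiceEncoder(device='cpu')
wav = torch.randn(1, 64000)  # (n_utterances, samples), 4 s at 16 kHz (illustrative)
emb = enc(wav)               # one 256-dim embedding per row
print(emb.shape)             # torch.Size([1, 256])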
network/modules/__init__.py
ADDED
File without changes
network/modules/attention.py
ADDED
@@ -0,0 +1,197 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .tf_gridnet_modules import AllHeadPReLULayerNormalization4DC, LayerNormalization
from einops import rearrange, repeat
import math


class IntraFrameCrossAttention(nn.Module):
    def __init__(
        self,
        emb_dim=48,
        n_head=4,
        qk_output_channel=12,
        activation="PReLU",
        eps=1e-5
    ):
        super().__init__()
        assert emb_dim % n_head == 0
        E = qk_output_channel
        self.conv_Q = nn.Conv2d(emb_dim, n_head*E, 1)
        self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)

        self.conv_K = nn.Conv2d(emb_dim, n_head*E, 1)
        self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)

        self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
        self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)

        self.concat_proj = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, 1),
            getattr(nn, activation)(),
            LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
        )
        self.emb_dim = emb_dim
        self.n_head = n_head

    def forward(self, q, kv):
        """
        args:
            q (torch.Tensor): query for cross attention, coming from the reference encoder
                [B D Q Tq]
            kv (torch.Tensor): key and value for cross attention, coming from the output of the feature split
                [B nSrc D Q Tkv]
        output:
            output (torch.Tensor): [B D Q Tkv]
        """

        B, D, freq, Tq = q.shape

        _, nSrc, _, _, Tkv = kv.shape
        if Tq >= Tkv:
            q = q[:, :, :, -Tkv:]
        else:
            r = math.ceil(Tkv/Tq)
            q = repeat(q, "B D Q T -> B D Q (T r)", r=r)
            q = q[:, :, :, -Tkv:]
        query = rearrange(q, "B D Q T -> B D T Q")
        kvInput = rearrange(kv, "B n D Q T -> B D T (n Q)")

        Q = self.norm_Q(self.conv_Q(query))    # [B, n_head, C, T, Q]
        K = self.norm_K(self.conv_K(kvInput))  # [B, n_head, C, T, Q*nSrc]
        V = self.norm_V(self.conv_V(kvInput))

        Q = rearrange(Q, "B H C T Q -> (B H T) Q C")
        K = rearrange(K, "B H C T Q -> (B H T) C Q").contiguous()
        _, n_head, channel, _, _ = V.shape
        V = rearrange(V, "B H C T Q -> (B H T) Q C")

        emb_dim = Q.shape[-1]
        qkT = torch.matmul(Q, K) / (emb_dim**0.5)
        qkT = F.softmax(qkT, dim=2)

        att = torch.matmul(qkT, V)
        att = rearrange(att, "(B H T) Q C -> B (H C) T Q", C=channel, Q=freq, H=n_head, B=B, T=Tkv)
        att = self.concat_proj(att)

        out = att + query
        out = rearrange(out, "B C T Q -> B C Q T")
        return out


class CrossFrameCrossAttention(nn.Module):
    def __init__(
        self,
        emb_dim=48,
        n_head=4,
        qk_output_channel=4,
        activation="PReLU",
        eps=1e-5
    ):
        super().__init__()
        assert emb_dim % n_head == 0
        E = qk_output_channel
        self.conv_Q = nn.Conv2d(emb_dim, n_head*E, 1)
        self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)

        self.conv_K = nn.Conv2d(emb_dim, n_head*E, 1)
        self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)

        self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
        self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)

        self.concat_proj = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, 1),
            getattr(nn, activation)(),
            LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
        )
        self.emb_dim = emb_dim
        self.n_head = n_head

    def forward(self, q, kv):
        """
        args:
            q (torch.Tensor): query for cross attention, coming from the reference encoder
                [B D Q Tq]
            kv (torch.Tensor): key and value for cross attention, coming from the output of the feature split
                [B D Q Tkv]
        output:
            output (torch.Tensor): [B D Q Tkv]
        """
        Tq = q.shape[-1]
        Tkv = kv.shape[-1]
        if Tq >= Tkv:
            q = q[:, :, :, -Tkv:]
        else:
            r = math.ceil(Tkv/Tq)
            q = repeat(q, "B D Q T -> B D Q (T r)", r=r)
            q = q[:, :, :, -Tkv:]

        input = rearrange(q, "B C Q T -> B C T Q")
        kvInput = rearrange(kv, "B C Q T -> B C T Q")

        Q = self.norm_Q(self.conv_Q(input))  # [B, n_head, C, T, Q]
        K = self.norm_K(self.conv_K(kvInput))
        V = self.norm_V(self.conv_V(kvInput))

        Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
        K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
        batch, n_head, channel, frame, freq = V.shape
        V = rearrange(V, "B H C T Q -> (B H) T (C Q)")
        emb_dim = Q.shape[-1]
        qkT = torch.matmul(Q, K) / (emb_dim**0.5)
        qkT = F.softmax(qkT, dim=2)
        att = torch.matmul(qkT, V)
        att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q", C=channel, Q=freq, H=n_head, B=batch, T=frame)
        att = self.concat_proj(att)
        out = att + input
        out = rearrange(out, "B C T Q -> B C Q T")
        return out

class CrossAttentionFilter(nn.Module):
    def __init__(self, emb_dim=48) -> None:
        super().__init__()
        self.emb_dim = emb_dim

    def forward(self, q, k, v):
        """
        Args:
            q (torch.Tensor): from the previous layer, [B D F T]
            k (torch.Tensor): from the speaker embedding encoder, [B D]
            v (torch.Tensor): from the speaker embedding encoder, [B D]
        """

        B, D, _, T = q.shape

        q = rearrange(q, "B D F T -> (B T) F D")
        k = repeat(k, "B D -> (B T) D 1", T=T)
        v = repeat(v, "B D -> (B T) 1 D", T=T)

        qkT = torch.matmul(q, k)/(D**0.5)  # [(B T) F 1]
        qkT = F.softmax(qkT, dim=-1)
        att = torch.matmul(qkT, v)  # [(B T) F D]
        att = rearrange(att, "(B T) F D -> B D F T", B=B, T=T)
        return att

class CrossAttentionFilterV2(nn.Module):
    def __init__(self, emb_dim=48) -> None:
        super().__init__()
        self.emb_dim = emb_dim

    def forward(self, q, kv):
        """
        Args:
            q: torch.Tensor, [B F D] query for cross attention, coming from the reference encoder (speaker embedding)
            kv: torch.Tensor, [B D F T] key and value for cross attention, coming from the output of the previous layer (TF-GridNet)
        """

        B, D, _, T = kv.shape

        Q = repeat(q, "B F D -> (B T) F D", T=T)
        K = rearrange(kv, "B D F T -> (B T) D F")
        V = rearrange(kv, "B D F T -> (B T) F D")

        qkT = torch.matmul(Q, K)/(D**0.5)  # [(B T) F F]
        qkT = F.softmax(qkT, dim=-1)
        att = torch.matmul(qkT, V)  # [(B T) F D]
        att = rearrange(att, "(B T) F D -> B D F T", B=B, T=T)
        return att
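A shape sketch for CrossAttentionFilterV2, assuming the module imports resolve; per time frame it attends across frequency bins using the clue-derived queries:

import torch
from network.modules.attention import CrossAttentionFilterV2

att = CrossAttentionFilterV2(emb_dim=48)
q = torch.randn(2, 65, 48)        # (B, F, D) per-frequency queries from the clue
kv = torch.randn(2, 48, 65, 100)  # (B, D, F, T) TF-GridNet features
out = att(q, kv)                  # (B, D, F, T), softmax taken over frequency
print(out.shape)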
network/modules/gate_module.py
ADDED
@@ -0,0 +1,16 @@
import torch.nn as nn
import torch
import torch.nn.functional as F

class BandFilterGate(nn.Module):
    def __init__(self, emb_dim=48, n_freqs=65):
        super().__init__()
        self.alpha = nn.Parameter(torch.empty(1, emb_dim, n_freqs, 1).to(torch.float32))
        self.beta = nn.Parameter(torch.empty(1, emb_dim, n_freqs, 1).to(torch.float32))
        nn.init.xavier_normal_(self.alpha)
        nn.init.xavier_normal_(self.beta)

    def forward(self, input, filters, bias):
        f = F.sigmoid(self.alpha*filters)
        b = F.tanh(self.beta*bias)
        return f*input + b
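The gate modulates each (channel, band) pair with a learned scale applied to the clue-derived filter and bias. A shape sketch, assuming the module imports resolve:

import torch
from network.modules.gate_module import BandFilterGate

gate = BandFilterGate(emb_dim=48, n_freqs=65)
x = torch.randn(2, 48, 65, 100)  # (B, D, F, T) features
f = torch.randn(2, 48, 65, 1)    # per-band filter from the clue
b = torch.randn(2, 48, 65, 1)    # per-band bias from the clue
y = gate(x, f, b)                # sigmoid(alpha*f)*x + tanh(beta*b), same shape as x
print(y.shape)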
network/modules/input_tranformation.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
from ..layers import STFTLayer
from ..utils import STFT_transform_type_enum
from typing import Iterable

class SimpleConv1DInput(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', activation: str = 'ReLU'):
        super().__init__()
        activation = getattr(nn, activation)
        self.model = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel, stride, padding, dilation, groups, bias, padding_mode),
            activation()
        )

    def forward(self, input):
        return self.model(input)

class STFTInput(nn.Module):
    def __init__(
        self,
        n_fft: int = 128,
        win_length: int = None,
        hop_length: int = 64,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        spec_transform_type: str = None,
        spec_factor: float = 0.15,
        spec_abs_exponent: float = 0.5,
    ):
        super().__init__()
        self.stft = STFTLayer(
            n_fft,
            win_length,
            hop_length,
            window,
            center,
            normalized,
            onesided
        )

        self.spec_transform_type = spec_transform_type
        self.spec_factor = spec_factor
        self.spec_abs_exponent = spec_abs_exponent

        self.spec_transform = lambda spec: spec
        if self.spec_transform_type == STFT_transform_type_enum.exponent:
            self.spec_transform = lambda spec: spec.abs() ** self.spec_abs_exponent * torch.exp(1j * spec.angle())
        elif self.spec_transform_type == STFT_transform_type_enum.log:
            self.spec_transform = lambda spec: torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle()) * self.spec_factor

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, input):
        """
        Note that PyTorch's STFT does not support 16-bit float inputs, so this
        function is decorated with @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32).
        Args:
            input (torch.Tensor): signal [Batch, Nsamples] or [Batch, channel, Nsamples]
        Outputs:
            spectrum (torch.Tensor): float tensor representing the spectrum with two channels;
                the first channel is the real part and the second is the imaginary part:
                [Batch, 2, F, T] or [Batch, 2 * channel, F, T]
        """

        spectrum = self.stft(input.float())
        spectrum = self.spec_transform(spectrum)

        re = spectrum.real
        im = spectrum.imag

        if input.dim() == 2:
            re = re.unsqueeze(1)
            im = im.unsqueeze(1)

        if input.dtype in (torch.float16, torch.bfloat16):
            re = re.to(dtype=input.dtype)
            im = im.to(dtype=input.dtype)

        return torch.cat([re, im], dim=1)

class RMSNormalizeInput(nn.Module):
    def __init__(self, dim: Iterable[int], keepdim: bool = True) -> None:
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, input):
        std = torch.std(input, dim=self.dim, keepdim=self.keepdim)
        output = input/std
        return output, std
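RMSNormalizeInput divides by the standard deviation over the given dims and returns it so the output can be rescaled later by RMSDenormalizeOutput. A small sketch, assuming the module imports resolve:

import torch
from network.modules.input_tranformation import RMSNormalizeInput

norm = RMSNormalizeInput(dim=(1, 2), keepdim=True)
x = torch.randn(2, 1, 32000) * 3.0
y, std = norm(x)    # y has unit std over (channel, time), per batch item
print(std.shape)    # torch.Size([2, 1, 1])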
network/modules/output_transformation.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
from ..layers import InverseSTFTLayer, ComplexTensorLayer
from typing import Iterable, Optional

class WaveGeneratorByISTFT(nn.Module):
    def __init__(
        self,
        n_fft: int = 128,
        win_length: Optional[int] = None,
        hop_length: int = 64,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True
    ) -> None:
        super().__init__()
        self.istft = InverseSTFTLayer(
            n_fft,
            win_length,
            hop_length,
            window,
            center,
            normalized,
            onesided
        )

        self.float_to_complex = ComplexTensorLayer()

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, input, length: int = None):
        x = input
        if input.dtype in (torch.float16, torch.bfloat16):
            x = input.float()
        if x.dtype in (torch.float32,):
            x = self.float_to_complex(x)

        wav = self.istft(x, length)
        wav = wav.to(dtype=input.dtype)

        return wav

class RMSDenormalizeOutput(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, std):
        return input*std
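A shape sketch for WaveGeneratorByISTFT, assuming the module imports resolve; it accepts the two-channel real format produced by STFTInput:

import torch
from network.modules.output_transformation import WaveGeneratorByISTFT

istft = WaveGeneratorByISTFT(n_fft=128, hop_length=64)
spec = torch.randn(2, 2, 65, 101)  # (B, re/im, F, T) float spectrum
wav = istft(spec, length=6400)     # (2, 6400) waveform; (T-1)*hop = 6400
print(wav.shape)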
network/modules/sequence_embed.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange
|
5 |
+
from .tf_gridnet_modules import CrossFrameSelfAttention
|
6 |
+
|
7 |
+
class SequenceEmbed(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
emb_dim: int = 48,
|
11 |
+
n_fft: int = 128,
|
12 |
+
hidden_size: int = 192,
|
13 |
+
kernel_T: int = 5,
|
14 |
+
kernel_F: int = 5,
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.n_freqs = n_fft // 2 + 1
|
19 |
+
self.emb_dim = emb_dim
|
20 |
+
|
21 |
+
self.conv = nn.Sequential(
|
22 |
+
nn.Conv2d(emb_dim*2,emb_dim*2,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2),groups=emb_dim*2),
|
23 |
+
nn.PReLU(),
|
24 |
+
nn.Conv2d(emb_dim*2,emb_dim*2,1),
|
25 |
+
nn.PReLU(),
|
26 |
+
nn.Conv2d(emb_dim*2,emb_dim*2,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2),groups=emb_dim*2),
|
27 |
+
nn.PReLU(),
|
28 |
+
nn.Conv2d(emb_dim*2,emb_dim,1),
|
29 |
+
nn.PReLU(),
|
30 |
+
)
|
31 |
+
|
32 |
+
self.linear_pre = nn.Conv1d(emb_dim*self.n_freqs,hidden_size,1)
|
33 |
+
|
34 |
+
self.lstm = nn.LSTM(
|
35 |
+
hidden_size,hidden_size,1,batch_first=True,bidirectional=True
|
36 |
+
)
|
37 |
+
|
38 |
+
self.linear = nn.Linear(hidden_size*2,emb_dim*self.n_freqs)
|
39 |
+
|
40 |
+
self.filter_gen = nn.Conv1d(emb_dim,emb_dim,1)
|
41 |
+
self.bias_gen = nn.Conv1d(emb_dim,emb_dim,1)
|
42 |
+
def forward(self,x,ref):
|
43 |
+
"""
|
44 |
+
Args:
|
45 |
+
x: (B, D, F, T) input tensor from prevous layer
|
46 |
+
ref: (B, D, F, T) embedding tensor previous layer
|
47 |
+
"""
|
48 |
+
B, D, n_freq, T = x.shape
|
49 |
+
input = torch.cat([x,ref],dim=1)
|
50 |
+
input = self.conv(input)
|
51 |
+
input = rearrange(input,'B D F T -> B (D F) T')
|
52 |
+
input = self.linear_pre(input)
|
53 |
+
input = rearrange(input,'B C T -> B T C')
|
54 |
+
rnn , _ = self.lstm(input)
|
55 |
+
feature = rnn[:,0]+rnn[:,-1] # (B, 2*Hidden)
|
56 |
+
feature = self.linear(feature) # (B, D*F)
|
57 |
+
feature = rearrange(feature,'B (D F) -> B D F',D=D,F=n_freq)
|
58 |
+
f = self.filter_gen(feature)
|
59 |
+
b = self.bias_gen(feature)
|
60 |
+
|
61 |
+
return f.unsqueeze(-1), b.unsqueeze(-1)
|
62 |
+
|
63 |
+
class CrossFrameCrossAttention(CrossFrameSelfAttention):
|
64 |
+
def __init__(self, emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4, activation="PReLU", eps=0.00001):
|
65 |
+
super().__init__(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
|
66 |
+
|
67 |
+
def forward(self, q, kv):
|
68 |
+
"""
|
69 |
+
Args:
|
70 |
+
q: (B, D, F, T) query tensor
|
71 |
+
kv: (B, D, F, T) key and value tensor
|
72 |
+
"""
|
73 |
+
|
74 |
+
input_q = rearrange(q,"B C Q T -> B C T Q")
|
75 |
+
input_kv = rearrange(kv,"B C Q T -> B C T Q")
|
76 |
+
|
77 |
+
Q = self.norm_Q(self.conv_Q(input_q))
|
78 |
+
K = self.norm_K(self.conv_K(input_kv))
|
79 |
+
V = self.norm_V(self.conv_V(input_kv))
|
80 |
+
Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
|
81 |
+
K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
|
82 |
+
batch, n_head, channel, frame, freq = V.shape
|
83 |
+
V = rearrange(V, "B H C T Q -> (B H) T (C Q)")
|
84 |
+
emb_dim = Q.shape[-1]
|
85 |
+
qkT = torch.matmul(Q, K) / (emb_dim**0.5)
|
86 |
+
qkT = F.softmax(qkT,dim=2)
|
87 |
+
att = torch.matmul(qkT,V)
|
88 |
+
att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q", C=channel, Q=freq, H = n_head, B = batch, T=frame)
|
89 |
+
att = self.concat_proj(att)
|
90 |
+
out = att + input_q
|
91 |
+
out = rearrange(out, "B C T Q -> B C Q T")
|
92 |
+
return out
|
93 |
+
|
+class MutualAttention(nn.Module):
+    def __init__(self, kernel_T=5, kernel_F=5, emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4, activation="PReLU", eps=0.00001):
+        super().__init__()
+
+        self.ref_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
+        self.tar_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
+
+        self.mt_conv = nn.Sequential(
+            nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
+            nn.PReLU(),
+            nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
+            nn.Sigmoid()
+        )
+
+        self.mr_conv = nn.Sequential(
+            nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
+            nn.PReLU(),
+            nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
+            nn.Sigmoid()
+        )
+
+        self.mtr_conv = nn.Sequential(
+            nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
+            nn.PReLU(),
+            nn.Conv2d(emb_dim,emb_dim,(kernel_F,kernel_T),padding=(kernel_F//2,kernel_T//2)),
+            nn.PReLU()
+        )
+    def forward(self,tar,ref):
+        """
+        Args:
+            ref: (B, D, F, T) reference tensor
+            tar: (B, D, F, T) target tensor
+        """
+
+        mr = self.ref_att(ref,tar)  # reference attends to the target
+        mt = self.tar_att(tar,ref)  # target attends to the reference
+
+        mrt = mr + mt
+
+        mr = self.mr_conv(mr)
+        mt = self.mt_conv(mt)
+        mrt_o = self.mtr_conv(mrt)
+
+        o = mr*mt*mrt_o + mrt  # gated fusion with a residual path
+        return o
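MutualAttention runs cross-attention in both directions, then fuses the two views: sigmoid gates (mr_conv, mt_conv) and a PReLU branch over their sum multiply together, with the sum kept as a residual. A hedged usage sketch, assuming the import path from this commit:

import torch
from network.modules.attention import MutualAttention

mutual = MutualAttention(emb_dim=48, n_freqs=65)
tar = torch.rand(1, 48, 65, 50)   # target features (B, D, F, T)
ref = torch.rand(1, 48, 65, 50)   # reference features, same shape
fused = mutual(tar, ref)
assert fused.shape == tar.shape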
network/modules/split_modules.py
ADDED
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+class DimensionDimAttention(nn.Module):
+    def __init__(
+        self,
+        emb_dim: int = 48,
+        kernel_size = (7,7),
+        dilation: int = 2,
+    ) -> None:
+        super().__init__()
+        self.emb_dim = emb_dim
+
+        self.attn = nn.Sequential(
+            nn.Conv2d(2*emb_dim,2*emb_dim,1),
+            nn.GELU(),
+            nn.Conv2d(2*emb_dim,2*emb_dim,kernel_size=kernel_size,groups=2*emb_dim,padding="same"),
+            nn.PReLU(),
+            nn.Conv2d(2*emb_dim,2*emb_dim,kernel_size=kernel_size,dilation=dilation,groups=2*emb_dim,padding="same"),
+            nn.PReLU(),
+            nn.Conv2d(2*emb_dim,emb_dim,1),
+            nn.Sigmoid()
+        )
+
+        self.transform = nn.Sequential(
+            nn.Conv2d(emb_dim,emb_dim,1),
+            nn.GELU(),
+            nn.Conv2d(emb_dim,emb_dim,kernel_size=kernel_size,groups=emb_dim,padding="same"),
+            nn.PReLU(),
+            nn.Conv2d(emb_dim,emb_dim,kernel_size=kernel_size,dilation=dilation,groups=emb_dim,padding="same"),
+            nn.PReLU(),
+            nn.Conv2d(emb_dim,emb_dim,1)
+        )
+
+    def forward(self,x,e):
+        """
+        Args:
+            x: (B, D, F, T) input tensor from the previous layer
+            e: (B, D, F) embedding after reshape
+        """
+
+        T = x.shape[-1]
+        emb = repeat(e, 'B D F -> B D F T', T=T)
+        att = torch.cat([x,emb],dim=1)
+
+        i = self.transform(x)
+        att = self.attn(att)
+        return i*att  # sigmoid attention map gates the transformed input
+
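DimensionDimAttention builds a sigmoid attention map from the concatenated features and embedding and multiplies it into a depthwise-transformed copy of the input, so shapes are preserved. A quick check, assuming the module path from this commit:

import torch
from network.modules.split_modules import DimensionDimAttention

d_att = DimensionDimAttention(emb_dim=48)
x = torch.rand(2, 48, 65, 100)    # (B, D, F, T) features
e = torch.rand(2, 48, 65)         # (B, D, F) embedding, repeated over T internally
y = d_att(x, e)
assert y.shape == x.shape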
+class FDAttention(nn.Module):
+    def __init__(
+        self
+    ) -> None:
+        super().__init__()
+    def forward(self,x,e):
+        """
+        Args:
+
+            x: (B, D, F, T) input tensor from the previous layer (for k and v)
+            e: (B, D, F) embedding after reshape (for q)
+        """
+
+        _,D,n_freq,T = x.shape
+        q = repeat(e, 'B D F -> B T (D F)', T=T)
+        k = rearrange(x, 'B D F T -> B (D F) T')
+        v = rearrange(x, 'B D F T -> B T (D F)')
+
+        q = self.positional_encoding(q)
+        qkT = torch.matmul(q,k)/((D*n_freq)**0.5)  # scaled dot-product scores over frames
+        qkT = F.softmax(qkT,dim=-1)
+        att = torch.matmul(qkT,v)
+
+        att = rearrange(att, 'B T (D F) -> B D F T', D=D, F=n_freq)
+        return att
+
+    def positional_encoding(self, x):
+        """
+        Args:
+            x: (B, T, D) input to add positional encoding
+        """
+        B, T, D = x.shape
+        pos = torch.arange(T, device=x.device).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, D, 2, device=x.device) * (-torch.log(torch.tensor(10000.0)) / D))
+
+        pos_enc = torch.zeros_like(x)
+        pos_enc[:, :, 0::2] = torch.sin(pos * div_term)  # standard sinusoidal encoding
+        pos_enc[:, :, 1::2] = torch.cos(pos * div_term)
+        return x + pos_enc
+
+class SplitModule(nn.Module):
+    def __init__(
+        self,
+        emb_dim: int = 48,
+        condition_dim: int = 256,
+        n_fft: int = 128,
+    ) -> None:
+        super().__init__()
+        self.emb_dim = emb_dim
+        self.condition_dim = condition_dim
+        n_freq = n_fft // 2 + 1
+        self.n_freqs = n_freq
+
+        self.alpha = nn.Parameter(torch.ones(1,emb_dim,self.n_freqs,dtype=torch.float32))  # was torch.empty, which leaves values undefined; overwritten when a checkpoint is loaded
+        self.beta = nn.Parameter(torch.ones(1,emb_dim,self.n_freqs,1,dtype=torch.float32))
+        self.d_att = DimensionDimAttention(emb_dim=emb_dim)
+
+        # self.f_att = FDAttention()
+    def forward(self,input,emb):
+        """
+        Args:
+            input: (B, D, F, T) input tensor
+            emb: (B, D, F) embedding after reshape
+        """
+        e = torch.tanh(emb*self.alpha)  # torch.tanh instead of the deprecated F.tanh
+
+        x = self.d_att(input,e)
+        # x = self.f_att(x,e)
+        x = x*torch.sigmoid(self.beta*emb.unsqueeze(-1))
+        return x
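SplitModule squashes the scaled embedding through tanh before the attention step and applies a second, sigmoid-gated scaling afterwards; n_freqs is derived as n_fft // 2 + 1. A hedged shape check (with the ones-initialized parameters above, the forward pass is well defined even before a checkpoint is loaded):

import torch
from network.modules.split_modules import SplitModule

split = SplitModule(emb_dim=48, n_fft=128)   # n_freqs = 128 // 2 + 1 = 65
feats = torch.rand(2, 48, 65, 100)           # (B, D, F, T)
cond = torch.rand(2, 48, 65)                 # (B, D, F) conditioning embedding
out = split(feats, cond)
assert out.shape == feats.shape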
network/modules/tf_gridnet_modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .dimension_embedding import DimensionEmbedding
+from .tf_gridnet_block import *
+from .deconv import *
network/modules/tf_gridnet_modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (369 Bytes)
network/modules/tf_gridnet_modules/__pycache__/deconv.cpython-311.pyc
ADDED
Binary file (1.6 kB)
network/modules/tf_gridnet_modules/__pycache__/dimension_embedding.cpython-311.pyc
ADDED
Binary file (1.63 kB)
network/modules/tf_gridnet_modules/__pycache__/tf_gridnet_block.cpython-311.pyc
ADDED
Binary file (15.7 kB)
network/modules/tf_gridnet_modules/deconv.py
ADDED
@@ -0,0 +1,20 @@
+import torch.nn as nn
+from einops import rearrange
+
+class TFGridnetDeconv(nn.Module):
+    def __init__(
+        self,
+        emb_dim = 48,
+        n_srcs = 2,
+        kernel_size_T = 3,
+        kernel_size_F = 3,
+        padding_T = 1,
+        padding_F = 1,
+    ) -> None:
+        super().__init__()
+        self.n_srcs = n_srcs
+        self.deconv = nn.ConvTranspose2d(emb_dim, n_srcs * 2, (kernel_size_F,kernel_size_T), padding=(padding_F,padding_T))
+    def forward(self,input):
+        output = self.deconv(input)
+        output = rearrange(output,"B (N C) F T -> B C N F T", C=self.n_srcs)  # (B, n_srcs, 2, F, T): a real/imaginary pair per source
+        return output
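With the default 3x3 kernel and padding 1, the transposed convolution keeps the (F, T) grid size, so the deconv only expands channels before they are regrouped per source. Sketch:

import torch
from network.modules.tf_gridnet_modules import TFGridnetDeconv

deconv = TFGridnetDeconv(emb_dim=48, n_srcs=2)
feats = torch.rand(1, 48, 65, 100)       # (B, emb_dim, F, T)
out = deconv(feats)
assert out.shape == (1, 2, 2, 65, 100)   # (B, n_srcs, real/imag, F, T)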
network/modules/tf_gridnet_modules/dimension_embedding.py
ADDED
@@ -0,0 +1,15 @@
+import torch.nn as nn
+from typing import Tuple
+class DimensionEmbedding(nn.Module):
+    def __init__(
+        self, audio_channel:int = 1,emb_dim:int = 48,
+        kernel_size: Tuple[int,int] = (3,3),
+        padding = "same",eps=1.0e-5
+    ) -> None:
+        super().__init__()
+        self.emb = nn.Sequential(
+            nn.Conv2d(2*audio_channel, emb_dim, kernel_size,padding=padding),  # 2*audio_channel: real and imaginary STFT planes
+            nn.GroupNorm(1,emb_dim,eps=eps)
+        )
+    def forward(self,input):
+        return self.emb(input)
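DimensionEmbedding expects the real and imaginary STFT planes stacked on the channel axis and lifts them to emb_dim channels. Sketch:

import torch
from network.modules.tf_gridnet_modules import DimensionEmbedding

embed = DimensionEmbedding(audio_channel=1, emb_dim=48)
spec = torch.rand(4, 2, 65, 100)     # (B, 2*audio_channel, F, T)
x = embed(spec)
assert x.shape == (4, 48, 65, 100)   # (B, emb_dim, F, T)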
network/modules/tf_gridnet_modules/tf_gridnet_block.py
ADDED
@@ -0,0 +1,255 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+import math
+
+if hasattr(torch, "bfloat16"):
+    HALF_PRECISION_DTYPES = (torch.float16, torch.bfloat16)
+else:
+    HALF_PRECISION_DTYPES = (torch.float16,)
+
+class IntraAndInterBandModule(nn.Module):
+    def __init__(
+        self, emb_dim:int = 48,
+        kernel_size:int = 4,
+        emb_hop_size:int = 1,
+        hidden_channels:int = 192,
+        eps = 1e-5
+    ) -> None:
+        super().__init__()
+        self.emb_dim = emb_dim
+        self.emb_hs = emb_hop_size
+        self.kernel_size = kernel_size
+        in_channels = emb_dim * kernel_size
+
+        self.intra_norm = nn.LayerNorm(emb_dim,eps=eps)
+
+        self.intra_lstm = nn.LSTM(
+            in_channels,hidden_channels,1,batch_first=True,bidirectional=True
+        )
+
+        if kernel_size == emb_hop_size:
+            self.intra_linear = nn.Linear(hidden_channels*2, in_channels)
+        else:
+            self.intra_linear = nn.ConvTranspose1d(hidden_channels*2, emb_dim,kernel_size ,emb_hop_size)
+
+        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
+        self.inter_lstm = nn.LSTM(
+            in_channels,hidden_channels,1,batch_first=True,bidirectional=True
+        )
+
+        if kernel_size == emb_hop_size:
+            self.inter_linear = nn.Linear(hidden_channels*2, in_channels)
+        else:
+            self.inter_linear = nn.ConvTranspose1d(hidden_channels*2, emb_dim,kernel_size ,emb_hop_size)
+    def forward(self,x):
+        """
+        Args:
+            x (torch.Tensor): [B C Q T]
+        Output:
+            output (torch.Tensor): [B C Q T]
+        """
+        B, C, old_Q, old_T = x.shape
+
+        padding = self.kernel_size - self.emb_hs
+
+        T = (
+            math.ceil((old_T + 2 * padding - self.kernel_size) / self.emb_hs) * self.emb_hs
+            + self.kernel_size
+        )
+        Q = (
+            math.ceil((old_Q + 2 * padding - self.kernel_size) / self.emb_hs) * self.emb_hs
+            + self.kernel_size
+        )
+
+        input = rearrange(x, "B C Q T -> B T Q C")
+        input = F.pad(input, (0, 0, padding, Q - old_Q - padding, padding, T - old_T - padding))
+        intra_rnn = self.intra_norm(input)
+        if self.kernel_size == self.emb_hs:
+            intra_rnn = intra_rnn.view([B * T, -1, self.kernel_size * C])
+            intra_rnn, _ = self.intra_lstm(intra_rnn)
+            intra_rnn = self.intra_linear(intra_rnn)
+            intra_rnn = intra_rnn.view([B, T, Q, C])
+        else:
+            intra_rnn = rearrange(intra_rnn,"B T Q C -> (B T) C Q")
+            intra_rnn = F.unfold(
+                intra_rnn[...,None],(self.kernel_size,1),stride=(self.emb_hs,1)
+            )
+            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, C*I]
+            intra_rnn, _ = self.intra_lstm(intra_rnn)
+            intra_rnn = intra_rnn.transpose(1, 2)  # [BT, H, -1]
+            intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
+            intra_rnn = intra_rnn.view([B, T, C, Q])
+            intra_rnn = intra_rnn.transpose(-2, -1)  # [B, T, Q, C]
+        intra_rnn = intra_rnn + input
+        inter_input = rearrange(intra_rnn, "B T Q C -> B Q T C")
+        inter_rnn = self.inter_norm(inter_input)
+        if self.kernel_size == self.emb_hs:
+            inter_rnn = inter_rnn.view([B * Q, -1, self.kernel_size * C])
+            inter_rnn, _ = self.inter_lstm(inter_rnn)
+            inter_rnn = self.inter_linear(inter_rnn)  # fixed: was self.inter_linear(intra_rnn)
+            inter_rnn = inter_rnn.view([B, Q, T, C])
+        else:
+            inter_rnn = rearrange(inter_rnn,"B Q T C -> (B Q) C T")
+            inter_rnn = F.unfold(
+                inter_rnn[...,None],(self.kernel_size,1),stride=(self.emb_hs,1)
+            )
+            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, -1, C*I]
+            inter_rnn,_ = self.inter_lstm(inter_rnn)
+            inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, H, -1]
+            inter_rnn = self.inter_linear(inter_rnn)  # [BQ, C, T]
+            inter_rnn = inter_rnn.view([B, Q, C, T])
+            inter_rnn = inter_rnn.transpose(-2, -1)  # [B, Q, T, C]
+        inter_rnn = inter_rnn + inter_input
+
+        inter_rnn = rearrange(inter_rnn,"B Q T C -> B C Q T")
+        inter_rnn = inter_rnn[..., padding : padding + old_Q, padding : padding + old_T]
+
+        return inter_rnn
+
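The module zero-pads both axes so the unfold windows (kernel_size positions with hop emb_hs) tile the padded grid exactly, runs a BiLSTM within each band and then across frames, and finally crops the padding away, making the block shape-preserving. A quick check with the defaults:

import torch
from network.modules.tf_gridnet_modules import IntraAndInterBandModule

band = IntraAndInterBandModule(emb_dim=48, kernel_size=4, emb_hop_size=1)
x = torch.rand(1, 48, 65, 100)   # (B, C, Q, T): Q frequency bands, T frames
y = band(x)
assert y.shape == x.shape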
+class LayerNormalization(nn.Module):
+    def __init__(self, input_dim, dim=1, total_dim=4, eps=1e-5):
+        super().__init__()
+        self.dim = dim if dim >= 0 else total_dim + dim
+        param_size = [1 if ii != self.dim else input_dim for ii in range(total_dim)]
+        self.gamma = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+        self.beta = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+        nn.init.ones_(self.gamma)
+        nn.init.zeros_(self.beta)
+        self.eps = eps
+
+    @torch.cuda.amp.autocast(enabled=False)
+    def forward(self, x):
+        if x.ndim - 1 < self.dim:
+            raise ValueError(
+                f"Expect x to have {self.dim + 1} dimensions, but got {x.ndim}"
+            )
+        if x.dtype in HALF_PRECISION_DTYPES:
+            dtype = x.dtype
+            x = x.float()  # compute the statistics in fp32 even under autocast
+        else:
+            dtype = None
+        mu_ = x.mean(dim=self.dim, keepdim=True)
+        std_ = torch.sqrt(x.var(dim=self.dim, unbiased=False, keepdim=True) + self.eps)
+        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
+        return x_hat.to(dtype=dtype) if dtype else x_hat
+
+class AllHeadPReLULayerNormalization4DC(nn.Module):
+    def __init__(self, input_dimension, eps=1e-5):
+        super().__init__()
+        assert len(input_dimension) == 2, input_dimension
+        H, E = input_dimension
+        param_size = [1, H, E, 1, 1]
+        self.gamma = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+        self.beta = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+        nn.init.ones_(self.gamma)
+        nn.init.zeros_(self.beta)
+        self.act = nn.PReLU(num_parameters=H, init=0.25)
+        self.eps = eps
+        self.H = H
+        self.E = E
+
+    def forward(self, x):
+        assert x.ndim == 4
+        B, _, T, F = x.shape
+        x = x.view([B, self.H, self.E, T, F])
+        x = self.act(x)  # [B,H,E,T,F]
+        stat_dim = (2,)
+        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,H,1,T,1]
+        std_ = torch.sqrt(
+            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
+        )  # [B,H,1,T,1]
+        x = ((x - mu_) / std_) * self.gamma + self.beta  # [B,H,E,T,F]
+        return x
+
+class CrossFrameSelfAttention(nn.Module):
+    def __init__(
+        self,
+        emb_dim = 48,
+        n_freqs = 65,
+        n_head=4,
+        qk_output_channel=4,
+        activation="PReLU",
+        eps = 1e-5
+
+    ):
+        super().__init__()
+        assert emb_dim % n_head == 0
+        E = qk_output_channel
+        self.conv_Q = nn.Conv2d(emb_dim,n_head*E,1)
+        self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
+
+        self.conv_K = nn.Conv2d(emb_dim,n_head*E,1)
+        self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
+
+        self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
+        self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)
+
+        self.concat_proj = nn.Sequential(
+            nn.Conv2d(emb_dim,emb_dim,1),
+            getattr(nn,activation)(),
+            LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
+        )
+        self.emb_dim = emb_dim
+        self.n_head = n_head
+    def forward(self,x):
+        """
+        Args:
+            x: (torch.Tensor) [B C Q T]
+        Output:
+            output: (torch.Tensor) [B C Q T]
+        """
+
+        input = rearrange(x,"B C Q T -> B C T Q")
+        Q = self.norm_Q(self.conv_Q(input))  # [B, n_head, C, T, Q]
+        K = self.norm_K(self.conv_K(input))
+        V = self.norm_V(self.conv_V(input))
+
+        Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
+        K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
+        batch, n_head, channel, frame, freq = V.shape
+        V = rearrange(V, "B H C T Q -> (B H) T (C Q)")
+        emb_dim = Q.shape[-1]
+        qkT = torch.matmul(Q, K) / (emb_dim**0.5)
+        qkT = F.softmax(qkT,dim=2)
+        att = torch.matmul(qkT,V)
+        att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q", C=channel, Q=freq, H=n_head, B=batch, T=frame)
+        att = self.concat_proj(att)
+        out = att + input
+        out = rearrange(out, "B C T Q -> B C Q T")
+        return out
+
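The attention folds the heads into the batch dimension and flattens (channel, frequency) into the token embedding, so each of the T frames is one token. Shape check:

import torch
from network.modules.tf_gridnet_modules import CrossFrameSelfAttention

attn = CrossFrameSelfAttention(emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4)
x = torch.rand(1, 48, 65, 100)   # (B, C, Q, T)
y = attn(x)
assert y.shape == x.shape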
+class TFGridnetBlock(nn.Module):
+    def __init__(
+        self,
+        emb_dim = 48,
+        kernel_size:int = 4,
+        emb_hop_size:int = 1,
+        n_freqs = 65,
+        hidden_channels:int = 192,
+        n_head=4,
+        qk_output_channel=4,
+        activation="PReLU",
+        eps = 1e-5
+    ):
+        super().__init__()
+        self.tf_grid_block = nn.Sequential(
+            IntraAndInterBandModule(
+                emb_dim=emb_dim,
+                kernel_size=kernel_size,
+                emb_hop_size=emb_hop_size,
+                hidden_channels=hidden_channels,
+                eps=eps
+            ),
+            CrossFrameSelfAttention(
+                emb_dim=emb_dim,
+                n_freqs=n_freqs,
+                n_head=n_head,
+                qk_output_channel=qk_output_channel,
+                activation=activation,
+                eps=eps
+            )
+        )
+    def forward(self,input):
+        return self.tf_grid_block(input)
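TFGridnetBlock chains the band module and the cross-frame attention, and both are shape-preserving, so a full model can stack several of these blocks back to back. Sketch:

import torch
from network.modules.tf_gridnet_modules import TFGridnetBlock

block = TFGridnetBlock(emb_dim=48, n_freqs=65)
x = torch.rand(1, 48, 65, 100)   # (B, C, Q, T)
y = block(x)
assert y.shape == x.shape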
network/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .error_message import ErrorMessageUtil
+from .enum_declare import *
network/utils/enum_declare.py
ADDED
@@ -0,0 +1,6 @@
+from dataclasses import dataclass
+
+@dataclass
+class STFT_transform_type_enum:
+    exponent = "exponent"
+    log = "log"
network/utils/error_message.py
ADDED
@@ -0,0 +1,7 @@
+from dataclasses import dataclass
+
+@dataclass
+class ErrorMessageUtil:
+    only_support_batch_input = "Only batch input is supported. For a single input, call .unsqueeze(0) first."
+    complex_format_convert = "the input must satisfy input.size(1) == 2"
+    two_input_in_the_same_shape = "the two inputs must have the same shape"
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+torch==2.0.1
+torchaudio==2.0.1
+numpy==1.23.5
+einops==0.7.0
+pandas==1.5.3
+scikit-learn==1.2.2
+resemblyzer
+gradio