|
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .tf_gridnet_modules import CrossFrameSelfAttention
|
|
|
class SequenceEmbed(nn.Module):
    """Embeds the input features together with a reference embedding into a
    per-channel, per-frequency (filter, bias) pair."""

    def __init__(
        self,
        emb_dim: int = 48,
        n_fft: int = 128,
        hidden_size: int = 192,
        kernel_T: int = 5,
        kernel_F: int = 5,
    ):
        super().__init__()

        self.n_freqs = n_fft // 2 + 1
        self.emb_dim = emb_dim

        # Two depthwise-separable blocks: a grouped (depthwise) conv over the
        # (F, T) plane followed by a pointwise 1x1 conv. The final pointwise
        # conv reduces the channel count from 2*emb_dim back to emb_dim.
        self.conv = nn.Sequential(
            nn.Conv2d(emb_dim * 2, emb_dim * 2, (kernel_F, kernel_T),
                      padding=(kernel_F // 2, kernel_T // 2), groups=emb_dim * 2),
            nn.PReLU(),
            nn.Conv2d(emb_dim * 2, emb_dim * 2, 1),
            nn.PReLU(),
            nn.Conv2d(emb_dim * 2, emb_dim * 2, (kernel_F, kernel_T),
                      padding=(kernel_F // 2, kernel_T // 2), groups=emb_dim * 2),
            nn.PReLU(),
            nn.Conv2d(emb_dim * 2, emb_dim, 1),
            nn.PReLU(),
        )

        # Project the flattened (emb_dim * n_freqs) channels to the LSTM size.
        self.linear_pre = nn.Conv1d(emb_dim * self.n_freqs, hidden_size, 1)

        self.lstm = nn.LSTM(
            hidden_size, hidden_size, 1, batch_first=True, bidirectional=True
        )

        # Map the bidirectional LSTM summary back to (emb_dim * n_freqs).
        self.linear = nn.Linear(hidden_size * 2, emb_dim * self.n_freqs)

        # 1x1 convolutions that emit the per-(channel, frequency) filter and bias.
        self.filter_gen = nn.Conv1d(emb_dim, emb_dim, 1)
        self.bias_gen = nn.Conv1d(emb_dim, emb_dim, 1)
|
    def forward(self, x, ref):
        """
        Args:
            x: (B, D, F, T) input tensor from the previous layer
            ref: (B, D, F, T) embedding tensor from the previous layer
        Returns:
            f, b: (B, D, F, 1) filter and bias tensors
        """
        B, D, n_freq, T = x.shape

        # Fuse the input and the reference embedding along the channel axis.
        feat = torch.cat([x, ref], dim=1)
        feat = self.conv(feat)

        # Flatten channel and frequency, project, and run the BiLSTM over time.
        feat = rearrange(feat, 'B D F T -> B (D F) T')
        feat = self.linear_pre(feat)
        feat = rearrange(feat, 'B C T -> B T C')
        rnn, _ = self.lstm(feat)

        # Sum the first and last frame outputs of the BiLSTM as a sequence summary.
        feature = rnn[:, 0] + rnn[:, -1]
        feature = self.linear(feature)
        feature = rearrange(feature, 'B (D F) -> B D F', D=D, F=n_freq)

        f = self.filter_gen(feature)
        b = self.bias_gen(feature)

        # Add a trailing singleton time axis so f and b broadcast over T.
        return f.unsqueeze(-1), b.unsqueeze(-1)
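

# SequenceEmbed returns a per-channel, per-frequency filter ``f`` and bias ``b``
# of shape (B, D, F, 1). How they are consumed is not shown in this file; the
# helper below is a minimal sketch of the assumed affine (FiLM-like) modulation,
# where both tensors broadcast over the time axis. The name ``apply_filter_bias``
# is hypothetical and not part of the original module.
def apply_filter_bias(x: torch.Tensor, f: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Modulate a (B, D, F, T) feature map with (B, D, F, 1) filter and bias."""
    return f * x + b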
|
|
|
class CrossFrameCrossAttention(CrossFrameSelfAttention):
    """Cross-attention variant of CrossFrameSelfAttention: queries come from one
    stream while keys and values come from another."""

    def __init__(self, emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4,
                 activation="PReLU", eps=1e-5):
        super().__init__(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
|
    def forward(self, q, kv):
        """
        Args:
            q: (B, D, F, T) query tensor
            kv: (B, D, F, T) key and value tensor
        Returns:
            (B, D, F, T) attended tensor with a residual connection to q
        """
        # The parent-class projections expect (B, C, T, Q) with frequency last.
        input_q = rearrange(q, "B C Q T -> B C T Q")
        input_kv = rearrange(kv, "B C Q T -> B C T Q")

        # Queries are projected from q; keys and values from kv.
        Q = self.norm_Q(self.conv_Q(input_q))   # (B, n_head, C, T, Q)
        K = self.norm_K(self.conv_K(input_kv))
        V = self.norm_V(self.conv_V(input_kv))

        # Flatten (channel, frequency) into the attention embedding and attend across frames.
        Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
        K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
        batch, n_head, channel, frame, freq = V.shape
        V = rearrange(V, "B H C T Q -> (B H) T (C Q)")

        emb_dim = Q.shape[-1]
        qkT = torch.matmul(Q, K) / (emb_dim ** 0.5)  # scaled dot-product scores, ((B H), T, T)
        qkT = F.softmax(qkT, dim=2)
        att = torch.matmul(qkT, V)

        # Merge heads back into channels, project, and add the residual.
        att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q",
                        C=channel, Q=freq, H=n_head, B=batch, T=frame)
        att = self.concat_proj(att)
        out = att + input_q
        out = rearrange(out, "B C T Q -> B C Q T")
        return out
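

# Minimal usage sketch (not part of the original module): with the defaults the
# block expects (B, emb_dim, n_freqs, T) inputs and returns a tensor of the same
# shape as q. Passing the same tensor as q and kv reduces it to per-frame
# self-attention.
def _cross_attention_shape_check() -> None:
    att = CrossFrameCrossAttention(emb_dim=48, n_freqs=65, n_head=4)
    q = torch.randn(2, 48, 65, 10)   # (B, D, F, T)
    kv = torch.randn(2, 48, 65, 10)
    out = att(q, kv)
    assert out.shape == q.shape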
|
|
|
class MutualAttention(nn.Module):
    """Mutual attention between a target and a reference stream: each stream
    cross-attends to the other, and the results are fused by gated convolutions."""

    def __init__(self, kernel_T=5, kernel_F=5, emb_dim=48, n_freqs=65, n_head=4,
                 qk_output_channel=4, activation="PReLU", eps=1e-5):
        super().__init__()

        # Reference attends to the target, and vice versa.
        self.ref_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
        self.tar_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)

        # Sigmoid-gated branches for the two attended streams.
        self.mt_conv = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, (kernel_F, kernel_T), padding=(kernel_F // 2, kernel_T // 2)),
            nn.PReLU(),
            nn.Conv2d(emb_dim, emb_dim, (kernel_F, kernel_T), padding=(kernel_F // 2, kernel_T // 2)),
            nn.Sigmoid(),
        )

        self.mr_conv = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, (kernel_F, kernel_T), padding=(kernel_F // 2, kernel_T // 2)),
            nn.PReLU(),
            nn.Conv2d(emb_dim, emb_dim, (kernel_F, kernel_T), padding=(kernel_F // 2, kernel_T // 2)),
            nn.Sigmoid(),
        )

        # PReLU branch for the summed (mutual) stream.
        self.mtr_conv = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, (kernel_F, kernel_T), padding=(kernel_F // 2, kernel_T // 2)),
            nn.PReLU(),
            nn.Conv2d(emb_dim, emb_dim, (kernel_F, kernel_T), padding=(kernel_F // 2, kernel_T // 2)),
            nn.PReLU(),
        )
|
    def forward(self, tar, ref):
        """
        Args:
            tar: (B, D, F, T) target tensor
            ref: (B, D, F, T) reference tensor
        Returns:
            (B, D, F, T) fused tensor
        """
        # Cross-attend in both directions: reference as query over the target,
        # and target as query over the reference.
        mr = self.ref_att(ref, tar)
        mt = self.tar_att(tar, ref)

        # Mutual (summed) stream.
        mrt = mr + mt

        # Gate each attended stream and mix with the mutual stream.
        mr = self.mr_conv(mr)
        mt = self.mt_conv(mt)
        mrt_o = self.mtr_conv(mrt)

        o = mr * mt * mrt_o + mrt
        return o
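

# A minimal smoke test, not part of the original module: it assumes the default
# shapes (B, D, F, T) = (2, 48, 65, 10) with n_fft=128 (hence 65 frequency bins)
# and the CrossFrameSelfAttention implementation from .tf_gridnet_modules.
# Because of the relative import above, run it as a module inside its package,
# e.g. ``python -m <package>.<this_module>``.
if __name__ == "__main__":
    B, D, n_freqs, T = 2, 48, 65, 10
    x = torch.randn(B, D, n_freqs, T)
    ref = torch.randn(B, D, n_freqs, T)

    seq_embed = SequenceEmbed(emb_dim=D, n_fft=128)
    f, b = seq_embed(x, ref)
    print("SequenceEmbed:", f.shape, b.shape)   # (B, D, F, 1) each

    mutual = MutualAttention(emb_dim=D, n_freqs=n_freqs)
    out = mutual(x, ref)
    print("MutualAttention:", out.shape)        # (B, D, F, T)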