# Source: Hugging Face application file by user hieugiaosu
# ("Add application file", commit 7596274, 5.06 kB).
import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
from .tf_gridnet_modules import CrossFrameSelfAttention
class SequenceEmbed(nn.Module):
    """Generate a per-channel (filter, bias) modulation pair from two feature maps.

    The input map and a reference map are concatenated along channels, fused by
    depthwise/pointwise 2-D convolutions, summarized over time by a BiLSTM, and
    projected into a multiplicative filter and an additive bias, each of shape
    (B, D, F, 1) so they broadcast over the time axis of a (B, D, F, T) map.
    """

    def __init__(
        self,
        emb_dim: int = 48,
        n_fft: int = 128,
        hidden_size: int = 192,
        kernel_T: int = 5,
        kernel_F: int = 5,
    ):
        """
        Args:
            emb_dim: number of embedding channels D.
            n_fft: FFT size; frequency bins F = n_fft // 2 + 1.
            hidden_size: BiLSTM hidden size per direction.
            kernel_T: conv kernel size along the time axis.
            kernel_F: conv kernel size along the frequency axis.
        """
        super().__init__()
        self.n_freqs = n_fft // 2 + 1
        self.emb_dim = emb_dim
        # Two depthwise (grouped) + pointwise conv stages fuse the 2*D-channel
        # concatenation; the final 1x1 conv brings channels back down to D.
        self.conv = nn.Sequential(
            nn.Conv2d(emb_dim * 2, emb_dim * 2, (kernel_F, kernel_T),
                      padding=(kernel_F // 2, kernel_T // 2), groups=emb_dim * 2),
            nn.PReLU(),
            nn.Conv2d(emb_dim * 2, emb_dim * 2, 1),
            nn.PReLU(),
            nn.Conv2d(emb_dim * 2, emb_dim * 2, (kernel_F, kernel_T),
                      padding=(kernel_F // 2, kernel_T // 2), groups=emb_dim * 2),
            nn.PReLU(),
            nn.Conv2d(emb_dim * 2, emb_dim, 1),
            nn.PReLU(),
        )
        self.linear_pre = nn.Conv1d(emb_dim * self.n_freqs, hidden_size, 1)
        self.lstm = nn.LSTM(
            hidden_size, hidden_size, 1, batch_first=True, bidirectional=True
        )
        self.linear = nn.Linear(hidden_size * 2, emb_dim * self.n_freqs)
        self.filter_gen = nn.Conv1d(emb_dim, emb_dim, 1)
        self.bias_gen = nn.Conv1d(emb_dim, emb_dim, 1)

    def forward(self, x, ref):
        """
        Args:
            x: (B, D, F, T) input tensor from the previous layer.
            ref: (B, D, F, T) embedding tensor from the previous layer.

        Returns:
            Tuple (filter, bias), each of shape (B, D, F, 1).
        """
        B, D, n_freq, T = x.shape
        # Fuse the two maps along the channel axis. (Renamed from `input`,
        # which shadowed the builtin.)
        feats = self.conv(torch.cat([x, ref], dim=1))   # (B, D, F, T)
        feats = feats.reshape(B, D * n_freq, T)         # 'B D F T -> B (D F) T'
        feats = self.linear_pre(feats)                  # (B, hidden, T)
        feats = feats.transpose(1, 2)                   # 'B C T -> B T C'
        rnn, _ = self.lstm(feats)                       # (B, T, 2*hidden)
        # Summarize the sequence with its first and last time steps, which
        # carry the full backward/forward context respectively.
        summary = rnn[:, 0] + rnn[:, -1]                # (B, 2*hidden)
        summary = self.linear(summary)                  # (B, D*F)
        summary = summary.reshape(B, D, n_freq)         # 'B (D F) -> B D F'
        f = self.filter_gen(summary)
        b = self.bias_gen(summary)
        return f.unsqueeze(-1), b.unsqueeze(-1)
class CrossFrameCrossAttention(CrossFrameSelfAttention):
    """Cross-attention across frames: queries come from one stream, keys and
    values from the other. Reuses the projection and norm layers built by
    CrossFrameSelfAttention; only the forward pass differs.
    """

    def __init__(self, emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4, activation="PReLU", eps=0.00001):
        super().__init__(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)

    def forward(self, q, kv):
        """
        Args:
            q: (B, D, F, T) query tensor.
            kv: (B, D, F, T) key/value tensor.

        Returns:
            (B, D, F, T) attention output with a residual connection to the
            (time-major) query input.
        """
        # Put time before frequency: (B, C, T, Q).
        q_tf = q.transpose(2, 3)
        kv_tf = kv.transpose(2, 3)
        # Multi-head projections; each is (B, H, C, T, Q). Queries project the
        # `q` stream, keys/values the `kv` stream.
        Q = self.norm_Q(self.conv_Q(q_tf))
        K = self.norm_K(self.conv_K(kv_tf))
        V = self.norm_V(self.conv_V(kv_tf))

        # Flatten heads into the batch and (channel, freq) into one token dim.
        bq, hq, cq, tq, fq = Q.shape
        Q = Q.permute(0, 1, 3, 2, 4).reshape(bq * hq, tq, cq * fq)      # (BH, T, CQ)
        bk, hk, ck, tk, fk = K.shape
        K = K.permute(0, 1, 2, 4, 3).reshape(bk * hk, ck * fk, tk)      # (BH, CQ, T)
        batch, n_head, channel, frame, freq = V.shape
        V = V.permute(0, 1, 3, 2, 4).reshape(batch * n_head, frame, channel * freq)

        # Scaled dot-product attention over frames; softmax along the key axis.
        scale = Q.shape[-1] ** 0.5
        scores = torch.matmul(Q, K) / scale
        weights = F.softmax(scores, dim=2)
        context = torch.matmul(weights, V)                               # (BH, T, CQ)

        # Re-assemble heads: (BH, T, CQ) -> (B, H*C, T, Q).
        context = context.reshape(batch, n_head, frame, channel, freq)
        context = context.permute(0, 1, 3, 2, 4).reshape(batch, n_head * channel, frame, freq)
        out = self.concat_proj(context) + q_tf
        # Restore (B, C, Q, T) layout.
        return out.transpose(2, 3)
class MutualAttention(nn.Module):
    """Fuse a target and a reference stream with symmetric cross-attention
    followed by learned sigmoid gating and a residual on the fused features.
    """

    def __init__(self, kernel_T=5, kernel_F=5, emb_dim=48, n_freqs=65, n_head=4, qk_output_channel=4, activation="PReLU", eps=0.00001):
        super().__init__()
        # One cross-attention per direction: reference attends to target and
        # vice versa.
        self.ref_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)
        self.tar_att = CrossFrameCrossAttention(emb_dim, n_freqs, n_head, qk_output_channel, activation, eps)

        pad = (kernel_F // 2, kernel_T // 2)
        ksize = (kernel_F, kernel_T)

        def conv_pair(final_act):
            # Two same-padded convs; `final_act` selects gate (Sigmoid) or
            # feature (PReLU) behaviour.
            return nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, ksize, padding=pad),
                nn.PReLU(),
                nn.Conv2d(emb_dim, emb_dim, ksize, padding=pad),
                final_act,
            )

        self.mt_conv = conv_pair(nn.Sigmoid())
        self.mr_conv = conv_pair(nn.Sigmoid())
        self.mtr_conv = conv_pair(nn.PReLU())

    def forward(self, tar, ref):
        """
        Args:
            tar: (B, D, F, T) target tensor.
            ref: (B, D, F, T) reference tensor.

        Returns:
            (B, D, F, T) fused tensor.
        """
        # Symmetric cross-attention: each stream queries the other.
        attended_ref = self.ref_att(ref, tar)
        attended_tar = self.tar_att(tar, ref)
        fused = attended_ref + attended_tar
        # Gated combination plus a residual on the un-gated fusion.
        gate_r = self.mr_conv(attended_ref)
        gate_t = self.mt_conv(attended_tar)
        features = self.mtr_conv(fused)
        return gate_r * gate_t * features + fused