import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .tf_gridnet_modules import AllHeadPReLULayerNormalization4DC, LayerNormalization


class IntraFrameCrossAttention(nn.Module):
    def __init__(
        self,
        emb_dim=48,
        n_head=4,
        qk_output_channel=12,
        activation="PReLU",
        eps=1e-5,
    ):
        super().__init__()
        assert emb_dim % n_head == 0
        E = qk_output_channel
        self.conv_Q = nn.Conv2d(emb_dim, n_head * E, 1)
        self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
        self.conv_K = nn.Conv2d(emb_dim, n_head * E, 1)
        self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
        self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
        self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)
        self.concat_proj = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, 1),
            getattr(nn, activation)(),
            LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
        )
        self.emb_dim = emb_dim
        self.n_head = n_head

    def forward(self, q, kv):
        """
        Args:
            q (torch.Tensor): query for the cross attention, coming from the
                reference encoder, of shape [B, D, Q, Tq].
            kv (torch.Tensor): key and value for the cross attention, coming
                from the output of the feature split, of shape
                [B, nSrc, D, Q, Tkv].
        Returns:
            torch.Tensor: output of shape [B, D, Q, Tkv].
        (See the illustrative shape check after this class.)
        """
        B, D, freq, Tq = q.shape
        _, nSrc, _, _, Tkv = kv.shape
        # Align the query's time axis with the key/value's: crop if the query
        # is longer, otherwise repeat frames and keep the last Tkv of them.
        if Tq >= Tkv:
            q = q[:, :, :, -Tkv:]
        else:
            r = math.ceil(Tkv / Tq)
            q = repeat(q, "B D Q T -> B D Q (T r)", r=r)
            q = q[:, :, :, -Tkv:]
        query = rearrange(q, "B D Q T -> B D T Q")
        # Stack the nSrc sources along the frequency axis so that, within each
        # frame, the query can attend over every source's frequency bins.
        kvInput = rearrange(kv, "B n D Q T -> B D T (n Q)")
        Q = self.norm_Q(self.conv_Q(query))    # [B, n_head, E, T, Q]
        K = self.norm_K(self.conv_K(kvInput))  # [B, n_head, E, T, nSrc*Q]
        V = self.norm_V(self.conv_V(kvInput))  # [B, n_head, D/n_head, T, nSrc*Q]
        Q = rearrange(Q, "B H C T Q -> (B H T) Q C")
        K = rearrange(K, "B H C T Q -> (B H T) C Q").contiguous()
        _, n_head, channel, _, _ = V.shape
        V = rearrange(V, "B H C T Q -> (B H T) Q C")
        emb_dim = Q.shape[-1]
        # Scaled dot-product attention over the stacked frequency axis.
        qkT = torch.matmul(Q, K) / (emb_dim ** 0.5)
        qkT = F.softmax(qkT, dim=2)
        att = torch.matmul(qkT, V)
        att = rearrange(att, "(B H T) Q C -> B (H C) T Q", C=channel, Q=freq, H=n_head, B=B, T=Tkv)
        att = self.concat_proj(att)
        out = att + query  # residual connection with the aligned query
        out = rearrange(out, "B C T Q -> B C Q T")
        return out
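
# A minimal illustrative shape check for IntraFrameCrossAttention (this helper
# and its sizes are arbitrary, and it assumes the tf_gridnet_modules import
# resolves): run random tensors through the layer and verify the documented
# [B, D, Q, Tkv] output shape.
def _intra_frame_example():
    attn = IntraFrameCrossAttention(emb_dim=48, n_head=4, qk_output_channel=12)
    q = torch.randn(2, 48, 65, 10)       # [B, D, Q, Tq]; Tq < Tkv exercises the repeat path
    kv = torch.randn(2, 3, 48, 65, 20)   # [B, nSrc, D, Q, Tkv]
    out = attn(q, kv)
    assert out.shape == (2, 48, 65, 20)  # [B, D, Q, Tkv]
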

class CrossFrameCrossAttention(nn.Module):
    def __init__(
        self,
        emb_dim=48,
        n_head=4,
        qk_output_channel=4,
        activation="PReLU",
        eps=1e-5,
    ):
        super().__init__()
        assert emb_dim % n_head == 0
        E = qk_output_channel
        self.conv_Q = nn.Conv2d(emb_dim, n_head * E, 1)
        self.norm_Q = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
        self.conv_K = nn.Conv2d(emb_dim, n_head * E, 1)
        self.norm_K = AllHeadPReLULayerNormalization4DC((n_head, E), eps=eps)
        self.conv_V = nn.Conv2d(emb_dim, emb_dim, 1)
        self.norm_V = AllHeadPReLULayerNormalization4DC((n_head, emb_dim // n_head), eps=eps)
        self.concat_proj = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, 1),
            getattr(nn, activation)(),
            LayerNormalization(emb_dim, dim=-3, total_dim=4, eps=eps),
        )
        self.emb_dim = emb_dim
        self.n_head = n_head

    def forward(self, q, kv):
        """
        Args:
            q (torch.Tensor): query for the cross attention, coming from the
                reference encoder, of shape [B, D, Q, Tq].
            kv (torch.Tensor): key and value for the cross attention, coming
                from the output of the feature split, of shape [B, D, Q, Tkv].
        Returns:
            torch.Tensor: output of shape [B, D, Q, Tkv].
        (See the illustrative shape check after this class.)
        """
        Tq = q.shape[-1]
        Tkv = kv.shape[-1]
        # Align the query's time axis with the key/value's: crop if the query
        # is longer, otherwise repeat frames and keep the last Tkv of them.
        if Tq >= Tkv:
            q = q[:, :, :, -Tkv:]
        else:
            r = math.ceil(Tkv / Tq)
            q = repeat(q, "B D Q T -> B D Q (T r)", r=r)
            q = q[:, :, :, -Tkv:]
        query = rearrange(q, "B C Q T -> B C T Q")
        kvInput = rearrange(kv, "B C Q T -> B C T Q")
        Q = self.norm_Q(self.conv_Q(query))  # [B, n_head, E, T, Q]
        K = self.norm_K(self.conv_K(kvInput))
        V = self.norm_V(self.conv_V(kvInput))
        # Attend across frames: each frame, with all of its frequency bins
        # flattened into one vector, acts as a single token.
        Q = rearrange(Q, "B H C T Q -> (B H) T (C Q)")
        K = rearrange(K, "B H C T Q -> (B H) (C Q) T").contiguous()
        batch, n_head, channel, frame, freq = V.shape
        V = rearrange(V, "B H C T Q -> (B H) T (C Q)")
        emb_dim = Q.shape[-1]
        qkT = torch.matmul(Q, K) / (emb_dim ** 0.5)
        qkT = F.softmax(qkT, dim=2)
        att = torch.matmul(qkT, V)
        att = rearrange(att, "(B H) T (C Q) -> B (H C) T Q", C=channel, Q=freq, H=n_head, B=batch, T=frame)
        att = self.concat_proj(att)
        out = att + query  # residual connection with the aligned query
        out = rearrange(out, "B C T Q -> B C Q T")
        return out
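
# A minimal illustrative shape check for CrossFrameCrossAttention (helper name
# and sizes are arbitrary; assumes the tf_gridnet_modules import resolves).
def _cross_frame_example():
    attn = CrossFrameCrossAttention(emb_dim=48, n_head=4, qk_output_channel=4)
    q = torch.randn(2, 48, 65, 10)       # [B, D, Q, Tq]
    kv = torch.randn(2, 48, 65, 20)      # [B, D, Q, Tkv]
    out = attn(q, kv)
    assert out.shape == (2, 48, 65, 20)  # [B, D, Q, Tkv]
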

class CrossAttentionFilter(nn.Module):
    def __init__(self, emb_dim=48) -> None:
        super().__init__()
        self.emb_dim = emb_dim

    def forward(self, q, k, v):
        """
        Args:
            q (torch.Tensor): from the previous layer, [B, D, F, T].
            k (torch.Tensor): from the speaker embedding encoder, [B, D].
            v (torch.Tensor): from the speaker embedding encoder, [B, D].
        Returns:
            torch.Tensor: output of shape [B, D, F, T].
        (See the illustrative shape check after this class.)
        """
        B, D, _, T = q.shape
        q = rearrange(q, "B D F T -> (B T) F D")
        k = repeat(k, "B D -> (B T) D 1", T=T)
        v = repeat(v, "B D -> (B T) 1 D", T=T)
        qkT = torch.matmul(q, k) / (D ** 0.5)  # [(B T), F, 1]
        # Note: the key dimension is a singleton here, so this softmax yields
        # weights that are identically 1 and v is broadcast to every bin.
        qkT = F.softmax(qkT, dim=-1)
        att = torch.matmul(qkT, v)  # [(B T), F, D]
        att = rearrange(att, "(B T) F D -> B D F T", B=B, T=T)
        return att
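
# A minimal illustrative shape check for CrossAttentionFilter (helper name and
# sizes are arbitrary); only torch and einops are exercised here.
def _filter_example():
    filt = CrossAttentionFilter(emb_dim=48)
    q = torch.randn(2, 48, 65, 30)       # [B, D, F, T]
    k = torch.randn(2, 48)               # [B, D]
    v = torch.randn(2, 48)               # [B, D]
    att = filt(q, k, v)
    assert att.shape == (2, 48, 65, 30)  # [B, D, F, T]
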

class CrossAttentionFilterV2(nn.Module):
    def __init__(self, emb_dim=48) -> None:
        super().__init__()
        self.emb_dim = emb_dim

    def forward(self, q, kv):
        """
        Args:
            q (torch.Tensor): query for the cross attention, coming from the
                reference encoder (speaker embedding), [B, F, D].
            kv (torch.Tensor): key and value for the cross attention, coming
                from the output of the previous layer (TF-GridNet), [B, D, F, T].
        Returns:
            torch.Tensor: output of shape [B, D, F, T].
        (See the illustrative shape check after this class.)
        """
        B, D, _, T = kv.shape
        Q = repeat(q, "B F D -> (B T) F D", T=T)
        K = rearrange(kv, "B D F T -> (B T) D F")
        V = rearrange(kv, "B D F T -> (B T) F D")
        qkT = torch.matmul(Q, K) / (D ** 0.5)  # [(B T), F, F]
        qkT = F.softmax(qkT, dim=-1)
        att = torch.matmul(qkT, V)  # [(B T), F, D]
        att = rearrange(att, "(B T) F D -> B D F T", B=B, T=T)
        return att
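
# A minimal illustrative shape check for CrossAttentionFilterV2 (helper name
# and sizes are arbitrary); only torch and einops are exercised here.
def _filter_v2_example():
    filt = CrossAttentionFilterV2(emb_dim=48)
    q = torch.randn(2, 65, 48)           # [B, F, D]
    kv = torch.randn(2, 48, 65, 30)      # [B, D, F, T]
    att = filt(q, kv)
    assert att.shape == (2, 48, 65, 30)  # [B, D, F, T]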