hieugiaosu
Add application file
7596274
raw
history blame
588 Bytes
import torch.nn as nn
import torch
import torch.nn.functional as F
class BandFilterGate(nn.Module):
    """Gate ``input`` with a learnable sigmoid scale and add a tanh shift.

    Holds two learnable float32 tensors, ``alpha`` and ``beta``, each of shape
    ``(1, emb_dim, n_freqs, 1)``. The forward pass computes::

        sigmoid(alpha * filters) * input + tanh(beta * bias)

    NOTE(review): the singleton dims 0 and 3 suggest the other operands are
    broadcast as (batch, emb_dim, n_freqs, time) tensors. Confirm against
    the callers.

    Args:
        emb_dim: size of dimension 1 of ``alpha`` and ``beta``.
        n_freqs: size of dimension 2 of ``alpha`` and ``beta``.
    """

    def __init__(self, emb_dim=48, n_freqs=65):
        super().__init__()
        param_shape = (1, emb_dim, n_freqs, 1)
        self.alpha = nn.Parameter(torch.empty(param_shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.empty(param_shape, dtype=torch.float32))
        # xavier_normal_ treats dim 1 as fan-in and dim 0 as fan-out, and
        # multiplies both by the trailing "receptive field" (n_freqs * 1).
        # The resulting std is sqrt(2 / (n_freqs * (emb_dim + 1))).
        # alpha is initialised before beta so the RNG draw order stays fixed.
        for weight in (self.alpha, self.beta):
            nn.init.xavier_normal_(weight)

    def forward(self, input, filters, bias):
        """Return ``sigmoid(alpha * filters) * input + tanh(beta * bias)``.

        Args:
            input: tensor being gated.
            filters: conditioning tensor that drives the multiplicative gate.
            bias: conditioning tensor that drives the additive shift.
        """
        gate = torch.sigmoid(self.alpha * filters)  # values in (0, 1)
        shift = torch.tanh(self.beta * bias)  # values in (-1, 1)
        return gate * input + shift