import torch.nn as nn
import torch
import torch.nn.functional as F


class BandFilterGate(nn.Module):
    """Learnable gate that modulates ``input`` using squashed conditioning tensors.

    The module owns two parameters, ``alpha`` and ``beta``, each of shape
    ``(1, emb_dim, n_freqs, 1)``. In :meth:`forward` they scale ``filters``
    and ``bias`` element-wise before a sigmoid / tanh, so every
    (dim-1 index, dim-2 index) pair gets its own learned scale while
    dims 0 and 3 are shared via broadcasting.

    Args:
        emb_dim: Size of dimension 1 of ``alpha`` / ``beta``.
        n_freqs: Size of dimension 2 of ``alpha`` / ``beta``.
    """

    def __init__(self, emb_dim: int = 48, n_freqs: int = 65) -> None:
        super().__init__()
        # Singleton dims 0 and 3 let the parameters broadcast over those axes
        # of the inputs. Allocate directly in float32 rather than allocating
        # in the default dtype and casting with ``.to``.
        shape = (1, emb_dim, n_freqs, 1)
        self.alpha = nn.Parameter(torch.empty(shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.empty(shape, dtype=torch.float32))
        # For a 4-D tensor, xavier_normal_ uses fan_in = size(1) * size(2) * size(3)
        # and fan_out = size(0) * size(2) * size(3).
        nn.init.xavier_normal_(self.alpha)
        nn.init.xavier_normal_(self.beta)

    def forward(
        self,
        input: torch.Tensor,  # name kept (shadows builtin) for keyword callers
        filters: torch.Tensor,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``sigmoid(alpha * filters) * input + tanh(beta * bias)``.

        Args:
            input: Tensor to be gated multiplicatively.
            filters: Conditioning tensor that drives the multiplicative gate.
            bias: Conditioning tensor that drives the additive term.

        Returns:
            The broadcast result of the gated ``input`` plus the additive term.

        Note:
            ``filters`` and ``bias`` must broadcast against
            ``(1, emb_dim, n_freqs, 1)``, and the results against ``input``.
            Presumably all three are ``(batch, emb_dim, n_freqs, time)`` --
            TODO confirm against callers.
        """
        # torch.sigmoid / torch.tanh are the canonical spellings; the
        # F.sigmoid / F.tanh aliases were deprecated in earlier releases.
        gate = torch.sigmoid(self.alpha * filters)
        shift = torch.tanh(self.beta * bias)
        return gate * input + shift