from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from ..utils import ErrorMessageUtil


class STFTLayer(nn.Module):
|
    def __init__(
        self,
        n_fft: int = 128,
        win_length: Optional[int] = None,
        hop_length: int = 64,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        pad_mode: str = "reflect",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        self.pad_mode = pad_mode
        # Resolve the window factory by name, e.g. "hann" -> torch.hann_window.
        self.window = getattr(torch, f"{window}_window")
|

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Channels, Nsamples)

        Returns:
            output: (Batch, Freq, Frames) or (Batch, Channels, Freq, Frames)

        Note:
            output is a complex tensor
        """
        assert input.dim() in (2, 3), ErrorMessageUtil.only_support_batch_input
|
        batch_size = input.size(0)
        multi_channel = input.dim() == 3

        # Fold channels into the batch axis; torch.stft expects (batch, samples).
        if multi_channel:
            input = rearrange(input, "b c l -> (b c) l")

        # Build the window on the input's device and dtype.
        window = self.window(
            self.win_length,
            dtype=input.dtype,
            device=input.device,
        )
|

        # torch.stft requires len(window) == win_length, so zero-pad the
        # win_length-point window to n_fft (centered) and pass win_length=n_fft.
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left
        window = torch.cat(
            [
                torch.zeros(n_pad_left, dtype=input.dtype, device=input.device),
                window,
                torch.zeros(n_pad_right, dtype=input.dtype, device=input.device),
            ],
            dim=0,
        )

        stft_kwargs = dict(
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.hop_length,
            center=self.center,
            window=window,
            normalized=self.normalized,
            onesided=self.onesided,
            pad_mode=self.pad_mode,
            return_complex=True,
        )
|

        output = torch.stft(input, **stft_kwargs)

        # Restore the channel axis for multi-channel input.
        if multi_channel:
            output = rearrange(output, "(b c) f t -> b c f t", b=batch_size)
        return output
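

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical values (run as part of the
    # package, e.g. `python -m <package>.<module>`, so the relative
    # ErrorMessageUtil import resolves).
    layer = STFTLayer(n_fft=512, hop_length=128)
    waveform = torch.randn(4, 2, 16000)  # (Batch, Channels, Nsamples)
    spec = layer(waveform)  # complex tensor: (Batch, Channels, Freq, Frames)
    # Freq = n_fft // 2 + 1 = 257 (onesided); Frames = 16000 // 128 + 1 = 126 (center=True).
    assert spec.shape == (4, 2, 257, 126)
    magnitude = spec.abs()  # magnitude spectrogram
    print(spec.shape, spec.dtype, magnitude.dtype)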