|
import torch |
|
import torch.nn as nn |
|
from typing import Optional |
|
from ..utils import ErrorMessageUtil |
|
from einops import rearrange |
|
|
|
class InverseSTFTLayer(nn.Module):
    """Inverse short-time Fourier transform wrapped as an ``nn.Module``.

    Args:
        n_fft: FFT size handed to ``torch.istft``.
        win_length: Length of the synthesis window. A falsy value
            (``None`` or ``0``) falls back to ``n_fft``.
        hop_length: Hop between consecutive frames.
        window: Window name; ``torch.<window>_window`` is looked up and
            used to build the window tensor on every forward call.
        center: Forwarded to ``torch.istft``.
        normalized: Forwarded to ``torch.istft``.
        onesided: Stored on the module. It is NOT forwarded to
            ``torch.istft``, which infers one-/two-sidedness from the
            frequency dimension of the input instead.
    """

    def __init__(
        self,
        n_fft: int = 128,
        win_length: Optional[int] = None,
        hop_length: int = 64,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length if win_length else n_fft
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        # Resolve the window factory by name, e.g. "hann" -> torch.hann_window.
        self.window = getattr(torch, f"{window}_window")

    def forward(self, input: torch.Tensor, audio_length: int) -> torch.Tensor:
        """Inverse STFT forward function.

        Args:
            input: (Batch, Freq, Frames) or (Batch, Channels, Freq, Frames)
            audio_length: number of samples of the reconstructed signal;
                passed to ``torch.istft`` as ``length``.

        Returns:
            output: (Batch, Nsamples) or (Batch, Channel, Nsample)

        Notice:
            input is a complex tensor
        """
        assert input.dim() in (3, 4), ErrorMessageUtil.only_support_batch_input

        batch_size = input.size(0)
        multi_channel = input.dim() == 4
        if multi_channel:
            # torch.istft only takes an optional single batch axis, so fold
            # the channel axis into the batch axis for the transform.
            input = rearrange(input, "b c f t -> (b c) f t")

        # Build the window with the real dtype and device of the spectrogram.
        window = self.window(
            self.win_length,
            dtype=input.real.dtype,
            device=input.device,
        )

        wave = torch.istft(
            input,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            # Must match the size of the window tensor built above; passing
            # n_fft here broke every configuration with win_length != n_fft.
            win_length=self.win_length,
            window=window,
            center=self.center,
            normalized=self.normalized,
            length=audio_length,
            return_complex=False,
        )

        if multi_channel:
            # Restore the channel axis that was folded into the batch axis.
            wave = rearrange(wave, "(b c) l -> b c l", b=batch_size)
        return wave
|
|
|
class ComplexTensorLayer(nn.Module):
    """Combine a stacked real/imaginary tensor into one complex tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Build a complex tensor from the two slices along dimension 1.

        Args:
            input: tensor whose dimension 1 has size 2; index 0 is used as
                the real part and index 1 as the imaginary part.

        Returns:
            ``torch.complex(input[:, 0], input[:, 1])``, i.e. the input
            with dimension 1 removed.
        """
        assert input.shape[1] == 2, ErrorMessageUtil.complex_format_convert
        real = input[:, 0]
        imag = input[:, 1]
        return torch.complex(real, imag)