File size: 2,307 Bytes
7596274 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
import torch.nn as nn
from typing import Optional
from ..utils import ErrorMessageUtil
from einops import rearrange
class InverseSTFTLayer(nn.Module):
    """Inverse short-time Fourier transform wrapped as an ``nn.Module``.

    Converts a complex spectrogram back to a waveform with ``torch.istft``.
    The synthesis window is created on every call so that it always matches
    the dtype and device of the incoming spectrogram.

    Args:
        n_fft: FFT size.
        win_length: Window length; falls back to ``n_fft`` when not given.
        hop_length: Hop between successive frames.
        window: Prefix of a ``torch.<name>_window`` factory (e.g. ``"hann"``
            selects ``torch.hann_window``).
        center: Forwarded to ``torch.istft``.
        normalized: Forwarded to ``torch.istft``.
        onesided: Stored on the module only; see the note in ``forward``.
    """

    def __init__(
        self,
        n_fft: int = 128,
        win_length: Optional[int] = None,
        hop_length: int = 64,
        window: str = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length or n_fft
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        # Keep the factory, not a tensor: the window is built lazily in
        # ``forward`` with the dtype/device of the input.
        self.window = getattr(torch, f"{window}_window")

    def forward(self, input: torch.Tensor, audio_length: int) -> torch.Tensor:
        """Run the inverse STFT.

        Args:
            input: Complex tensor shaped (Batch, Freq, Frames) or
                (Batch, Channels, Freq, Frames).
            audio_length: Number of samples of the reconstructed signal,
                passed to ``torch.istft`` as ``length``.

        Returns:
            Real tensor shaped (Batch, Nsamples) or
            (Batch, Channels, Nsamples), matching the input layout.
        """
        assert input.dim() in (3, 4), ErrorMessageUtil.only_support_batch_input

        batch_size = input.size(0)
        multi_channel = input.dim() == 4
        if multi_channel:
            # torch.istft only accepts (Batch, Freq, Frames): fold the
            # channel axis into the batch axis and restore it afterwards.
            input = rearrange(input, "b c f t -> (b c) f t")

        window = self.window(
            self.win_length,
            dtype=input.real.dtype,
            device=input.device,
        )

        # win_length must equal the window's size, otherwise torch.istft
        # rejects the call whenever win_length != n_fft.
        # NOTE: ``onesided`` is intentionally not forwarded; torch.istft
        # infers it from the frequency dimension, which keeps callers that
        # relied on that inference working.
        wave = torch.istft(
            input,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            center=self.center,
            normalized=self.normalized,
            length=audio_length,
            return_complex=False,
        )

        if multi_channel:
            wave = rearrange(wave, "(b c) l -> b c l", b=batch_size)
        return wave
class ComplexTensorLayer(nn.Module):
    """Pack a stacked real/imaginary tensor into a single complex tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Combine the two planes on axis 1 into a complex tensor.

        Args:
            input: Real tensor whose axis 1 has size 2; index 0 is used as
                the real part and index 1 as the imaginary part.

        Returns:
            Complex tensor with axis 1 removed, i.e. shape
            ``input.shape[:1] + input.shape[2:]``.
        """
        assert input.shape[1] == 2, ErrorMessageUtil.complex_format_convert
        real = input[:, 0]
        imag = input[:, 1]
        return torch.complex(real, imag)