wanna_hear_your_voice / network /modules /input_tranformation.py
hieugiaosu
Add application file
7596274
raw
history blame
3.4 kB
import torch
import torch.nn as nn
from ..layers import STFTLayer
from ..utils import STFT_transform_type_enum
from typing import Iterable
class SimpleConv1DInput(nn.Module):
    """A single 1-D convolution followed by a configurable activation.

    Args:
        in_channels, out_channels: channel counts handed to ``nn.Conv1d``.
        kernel: convolution kernel size.
        stride, padding, dilation, groups, bias, padding_mode: forwarded
            unchanged to ``nn.Conv1d``.
        activation: name of an activation class looked up on ``torch.nn``
            (for example ``'ReLU'``); it is instantiated with no arguments.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode='zeros',
        activation: str = 'ReLU',
    ):
        super().__init__()
        # Resolve the activation class first so an unknown name fails before
        # any layer is allocated.
        activation_cls = getattr(nn, activation)
        convolution = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        # Index 0 is the convolution and index 1 the activation; this layout
        # determines the parameter names in the state dict.
        self.model = nn.Sequential(convolution, activation_cls())

    def forward(self, input):
        """Apply the convolution and then the activation to ``input``."""
        features = self.model(input)
        return features
class STFTInput(nn.Module):
    """Convert a waveform into a real tensor holding its complex STFT.

    The signal goes through ``STFTLayer``, is optionally compressed in
    magnitude while keeping the phase, and the real and imaginary parts are
    then concatenated along dimension 1.

    Args:
        n_fft, win_length, hop_length, window, center, normalized, onesided:
            forwarded positionally and unchanged to ``STFTLayer``.
        spec_transform_type: selects the magnitude compression. It is compared
            against ``STFT_transform_type_enum.exponent`` and
            ``STFT_transform_type_enum.log``; any other value (including
            ``None``) leaves the spectrum untouched.
        spec_factor: scale applied to the spectrum by the ``log`` transform.
        spec_abs_exponent: exponent applied to the magnitude by the
            ``exponent`` transform.
    """

    def __init__(
        self,
        n_fft: int = 128,
        win_length: int = None,
        hop_length: int = 64,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        spec_transform_type: str = None,
        spec_factor: float = 0.15,
        spec_abs_exponent: float = 0.5,
    ):
        super().__init__()
        self.stft = STFTLayer(
            n_fft,
            win_length,
            hop_length,
            window,
            center,
            normalized,
            onesided
        )
        self.spec_transform_type = spec_transform_type
        self.spec_factor = spec_factor
        self.spec_abs_exponent = spec_abs_exponent

    def spec_transform(self, spec):
        """Apply the configured magnitude compression to a complex spectrum.

        This is a regular method rather than a lambda stored on the instance:
        lambdas cannot be pickled, which would make the whole module fail
        under ``torch.save(module)`` or when sent to a spawned process.
        The transform type is read on every call, so changing
        ``spec_transform_type`` after construction takes effect.

        Args:
            spec (torch.Tensor): complex spectrum.
        Returns:
            torch.Tensor: complex spectrum with the same phase as ``spec``.
        """
        kind = self.spec_transform_type
        if kind is None:
            return spec
        if kind == STFT_transform_type_enum.exponent:
            # NOTE(review): unlike the log branch, spec_factor is not applied
            # here. Left unchanged to keep numerics identical; confirm intent.
            return spec.abs() ** self.spec_abs_exponent * torch.exp(1j * spec.angle())
        if kind == STFT_transform_type_enum.log:
            return torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle()) * self.spec_factor
        # Unrecognised types fall back to the identity.
        return spec

    # Per the PyTorch AMP documentation, custom_fwd with cast_inputs casts
    # floating-point tensor arguments to float32 and runs the function with
    # autocast disabled whenever it is called inside an autocast region.
    # Recent PyTorch releases deprecate this spelling in favour of
    # torch.amp.custom_fwd(device_type="cuda", ...); it is kept for
    # compatibility with older releases.
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, input):
        """
        Compute the (optionally compressed) spectrum of a batch of signals.

        The STFT is always evaluated in float32; the original author notes
        that PyTorch's STFT does not support 16-bit floats, which is why this
        method is decorated with
        @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32).
        Args:
            input (torch.Tensor): signal [Batch, Nsamples] or [Batch, channel, Nsamples]
        Returns:
            spectrum (torch.Tensor): float tensor holding the spectrum; along
                dimension 1 the real part(s) come first, followed by the
                imaginary part(s):
                [Batch, 2, F, T] or [Batch, 2 * channel, F, T]
        """
        spectrum = self.stft(input.float())
        spectrum = self.spec_transform(spectrum)
        re = spectrum.real
        im = spectrum.imag
        if input.dim() == 2:
            # Add the channel axis so the concatenation below is along dim 1.
            re = re.unsqueeze(1)
            im = im.unsqueeze(1)
        if input.dtype in (torch.float16, torch.bfloat16):
            # Reached when a half-precision tensor arrives outside autocast:
            # hand the result back in the caller's dtype.
            re = re.to(dtype=input.dtype)
            im = im.to(dtype=input.dtype)
        return torch.cat([re, im], dim=1)
class RMSNormalizeInput(nn.Module):
    """Scale the input by its standard deviation and return both.

    Despite the class name, the statistic is ``torch.std`` over ``dim``
    (mean-removed, using torch's default correction), not a plain
    root-mean-square.

    Args:
        dim: dimension(s) reduced when computing the standard deviation;
            passed as-is to ``torch.std``.
        keepdim: forwarded to ``torch.std``. With ``False`` the division
            relies on ordinary broadcasting between the input and the
            reduced tensor.
        eps: lower bound applied to the standard deviation before dividing,
            so silent or constant inputs give zeros instead of NaN/inf.
    """

    def __init__(self, dim: Iterable[int], keepdim: bool = True, eps: float = 1e-8) -> None:
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
        self.eps = eps

    def forward(self, input):
        """Normalize ``input``.

        Returns:
            tuple: ``(output, std)`` where ``output = input / std`` and
            ``std`` is the floored standard deviation that was used, so that
            ``output * std`` recovers the input.
        """
        std = torch.std(input, dim=self.dim, keepdim=self.keepdim)
        # eps may underflow to zero in half precision, so never go below the
        # smallest positive normal number of the working dtype.
        floor = max(self.eps, torch.finfo(std.dtype).tiny)
        std = std.clamp_min(floor)
        output = input / std
        return output, std