File size: 2,307 Bytes
7596274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
from typing import Optional
from ..utils import ErrorMessageUtil
from einops import rearrange

class InverseSTFTLayer(nn.Module):
    """Inverse short-time Fourier transform wrapped as an ``nn.Module``.

    The layer holds the STFT configuration and, on every call, converts a
    complex spectrogram back into a waveform with ``torch.istft``.

    Args:
        n_fft: Size of the Fourier transform.
        win_length: Length of the analysis window. Falls back to ``n_fft``
            when ``None`` (or ``0``) is given.
        hop_length: Distance between neighbouring frames, in samples.
        window: Name of the window; resolved to ``torch.<window>_window``
            (e.g. ``"hann"`` -> ``torch.hann_window``).
        center: Forwarded to ``torch.istft`` as ``center``.
        normalized: Forwarded to ``torch.istft`` as ``normalized``.
        onesided: Forwarded to ``torch.istft`` as ``onesided``.

    Raises:
        AttributeError: If ``torch`` has no ``<window>_window`` function.
    """

    def __init__(
            self,
            n_fft: int = 128,
            win_length: Optional[int] = None,
            hop_length: int = 64,
            window: str = "hann",
            center: bool = True,
            normalized: bool = False,
            onesided: bool = True,
            ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length if win_length else n_fft
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        # Keep the window *factory*, not a tensor: the actual window is built
        # in forward() so it follows the dtype and device of each input.
        self.window = getattr(torch, f"{window}_window")

    def forward(self, input, audio_length: int):
        """Reconstruct a waveform from a complex spectrogram.

        Args:
            input: Complex tensor shaped (Batch, Freq, Frames) or
                (Batch, Channels, Freq, Frames).
            audio_length: Number of samples of the reconstructed signal;
                passed to ``torch.istft`` as ``length``.

        Returns:
            Real tensor shaped (Batch, Nsamples) or
            (Batch, Channels, Nsamples), matching the layout of ``input``.

        Raises:
            AssertionError: If ``input`` is neither 3- nor 4-dimensional.
        """
        assert input.dim() == 4 or input.dim() == 3, ErrorMessageUtil.only_support_batch_input
        batch_size = input.size(0)
        multi_channel = (input.dim() == 4)
        if multi_channel:
            # torch.istft only accepts one leading batch axis, so channels
            # are folded into the batch and unfolded again afterwards.
            input = rearrange(input, "b c f t -> (b c) f t")

        # The window must be real-valued, hence the dtype of the real part.
        window = self.window(
            self.win_length,
            dtype=input.real.dtype,
            device=input.device,
        )

        wave = torch.istft(
            input,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            # Must equal the length of `window`; using n_fft here breaks
            # every configuration with win_length != n_fft.
            win_length=self.win_length,
            window=window,
            center=self.center,
            normalized=self.normalized,
            onesided=self.onesided,
            length=audio_length,
            return_complex=False,
        )

        if multi_channel:
            wave = rearrange(wave, "(b c) l -> b c l", b=batch_size)
        return wave
    
class ComplexTensorLayer(nn.Module):
    """Turn a stacked real/imaginary tensor into a complex tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Combine the two planes on axis 1 into one complex tensor.

        Args:
            input: Tensor whose axis 1 has size 2; index 0 holds the real
                part and index 1 the imaginary part.

        Returns:
            ``torch.complex(real, imag)``; axis 1 of ``input`` is consumed,
            so the result has one dimension fewer.

        Raises:
            AssertionError: If axis 1 of ``input`` does not have size 2.
        """
        assert input.shape[1] == 2, ErrorMessageUtil.complex_format_convert
        real = input[:, 0]
        imag = input[:, 1]

        return torch.complex(real, imag)