Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,526 Bytes
fa90792 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
import numpy as np
from scipy.io.wavfile import write
import torchaudio
from audiosr.utilities.audio.audio_processing import griffin_lim
def pad_wav(waveform, segment_length):
    """Crop or zero-pad ``waveform`` along its last axis to ``segment_length``.

    Args:
        waveform: NumPy array whose last axis indexes samples.
        segment_length: Target number of samples, or ``None`` to leave the
            waveform untouched.

    Returns:
        * the input itself when ``segment_length`` is ``None`` or already
          equals the waveform length;
        * the input truncated along its last axis when it is too long;
        * a new float64 array of shape ``(1, segment_length)`` holding the
          input followed by zeros when it is too short.

    Raises:
        AssertionError: if the waveform has 100 samples or fewer.
    """
    waveform_length = waveform.shape[-1]
    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

    if segment_length is None or waveform_length == segment_length:
        return waveform

    if waveform_length > segment_length:
        # Slice the last (sample) axis, the same axis the length was measured
        # on. Slicing the first axis would leave a (1, N) array untruncated.
        return waveform[..., :segment_length]

    # Shorter than requested: right-pad with silence.
    temp_wav = np.zeros((1, segment_length))
    temp_wav[:, :waveform_length] = waveform
    return temp_wav
def normalize_wav(waveform):
    """Remove the mean from ``waveform`` and rescale its peak to about 0.5.

    The mean over all elements is subtracted, the result is divided by its
    largest absolute value (plus ``1e-8`` so an all-constant input does not
    divide by zero), and finally halved.

    Args:
        waveform: NumPy array of samples.

    Returns:
        A new array with the same shape as ``waveform``.
    """
    centered = waveform - np.mean(waveform)
    peak = np.max(np.abs(centered))
    scaled = centered / (peak + 1e-8)
    return scaled * 0.5
def read_wav_file(filename, segment_length):
    """Load an audio file as a peak-normalised array at 16 kHz.

    The file is loaded with torchaudio, resampled to 16000 Hz, reduced to the
    first row of the loaded tensor, normalised with ``normalize_wav``, given a
    leading axis of size 1, cropped or padded with ``pad_wav``, and finally
    rescaled so that its absolute peak is 0.5.

    Args:
        filename: Path of the audio file, passed to ``torchaudio.load``.
        segment_length: Target number of samples handed to ``pad_wav``
            (``None`` keeps the natural length).

    Returns:
        NumPy array with a leading axis of size 1 and an absolute peak of 0.5.
        An all-zero (silent) signal is returned as zeros rather than NaN.
    """
    # torchaudio is used instead of librosa because loading with librosa was
    # observed to be roughly four times slower.
    waveform, sr = torchaudio.load(filename)
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    # Keep only index 0 of the first axis (presumably the first channel;
    # any additional rows are discarded).
    waveform = waveform.numpy()[0, ...]
    waveform = normalize_wav(waveform)
    waveform = waveform[None, ...]
    waveform = pad_wav(waveform, segment_length)

    # Re-normalise after padding/cropping. Skip the division for silent audio,
    # where the peak is zero and dividing would fill the result with NaN.
    peak = np.max(np.abs(waveform))
    if peak > 0:
        waveform = waveform / peak
    return 0.5 * waveform
def get_mel_from_wav(audio, _stft):
    """Compute mel spectrogram, magnitudes and energy for ``audio``.

    Args:
        audio: Samples accepted by ``torch.FloatTensor`` (e.g. a nested list
            or NumPy array). Values are clipped to ``[-1, 1]``.
        _stft: Object exposing ``mel_spectrogram(tensor)`` that returns a
            4-tuple ``(melspec, magnitudes, phases, energy)`` of tensors.

    Returns:
        Tuple ``(melspec, magnitudes, energy)`` of float32 NumPy arrays, each
        with a size-1 leading axis squeezed away. The phases are discarded.
    """
    # Add a leading axis of size 1 and clamp out-of-range samples.
    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
    # No torch.autograd.Variable wrapper: it is deprecated and a plain tensor
    # already has requires_grad=False.
    melspec, magnitudes, _phases, energy = _stft.mel_spectrogram(audio)

    def _to_numpy(tensor):
        # detach()/cpu() make the conversion safe even when the transform
        # returns tensors that track gradients or live on an accelerator.
        return torch.squeeze(tensor, 0).detach().cpu().numpy().astype(np.float32)

    return _to_numpy(melspec), _to_numpy(magnitudes), _to_numpy(energy)
def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
    """Invert a mel spectrogram to audio with Griffin-Lim and save it as WAV.

    Args:
        mel: Mel spectrogram tensor; a leading axis of size 1 is added via
            ``torch.stack`` before de-normalisation.
        out_filename: Destination path handed to ``scipy.io.wavfile.write``.
        _stft: Object providing ``spectral_de_normalize``, ``mel_basis``,
            ``_stft_fn`` and ``sampling_rate``.
        griffin_iters: Iteration count forwarded to ``griffin_lim``.

    Returns:
        None. The reconstructed audio is written to ``out_filename``.
    """
    batched = torch.stack([mel])
    linear_mel = _stft.spectral_de_normalize(batched).transpose(1, 2).data.cpu()

    # Project back through the mel basis, then apply a fixed gain of 1000
    # before phase reconstruction.
    scale = 1000
    spectrogram = torch.mm(linear_mel[0], _stft.mel_basis)
    spectrogram = spectrogram.transpose(0, 1).unsqueeze(0) * scale

    # The last index along the final axis is dropped before Griffin-Lim.
    griffin_input = torch.autograd.Variable(spectrogram[:, :, :-1])
    waveform = griffin_lim(griffin_input, _stft._stft_fn, griffin_iters)

    write(out_filename, _stft.sampling_rate, waveform.squeeze().cpu().numpy())
|