# ClearVoice-SR: dataloader/meldataset.py
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
import scipy.signal
import librosa
import wave
from pydub import AudioSegment
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
    try:
        sampling_rate, data = read(full_path)
        # Skip clips shorter than 0.5 seconds.
        if max(data.shape) / sampling_rate < 0.5:
            return None, None
    except FileNotFoundError:
        print(f"File not found: {full_path}")
        return None, None
except Exception as e:
print(f"An unexpected error occurred: {e}")
return None, None
    if len(data.shape) > 1:
        # Multi-channel audio: keep only the first channel.
        if data.shape[1] <= 2:
            data = data[..., 0]
        else:
            data = data[0, ...]
return data / MAX_WAV_VALUE, sampling_rate
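# Illustrative usage of load_wav (a sketch, not part of the original pipeline;
# the path below is hypothetical):
#   data, sr = load_wav('/path/to/example.wav')
#   if data is not None:
#       print(data.shape, sr, data.min(), data.max())  # floats roughly in [-1, 1]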
def get_wave_duration(file_path):
"""
Gets the duration of a WAV file in seconds.
:param file_path: Path to the WAV file.
    :return: Tuple of (duration in seconds, frame rate in Hz, number of frames), or (None, None, None) on error.
"""
try:
with wave.open(file_path, 'rb') as wf:
# Get the number of frames
num_frames = wf.getnframes()
# Get the frame rate
frame_rate = wf.getframerate()
# Calculate duration
duration = num_frames / float(frame_rate)
return duration, frame_rate, num_frames
except wave.Error as e:
print(f"Error reading {file_path}: {e}")
return None, None, None
except FileNotFoundError:
print(f"File not found: {file_path}")
return None, None, None
except Exception as e:
print(f"An unexpected error occurred: {e}")
return None, None, None
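# Illustrative usage of get_wave_duration (sketch; the path is hypothetical):
#   duration, frame_rate, num_frames = get_wave_duration('/path/to/example.wav')
#   if duration is not None:
#       print(f'{duration:.2f} s at {frame_rate} Hz ({num_frames} frames)')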
def read_audio_segment(file_path, start_ms, end_ms):
"""
Reads a segment from a WAV file and returns the raw data and its properties.
:param file_path: Path to the WAV file.
:param start_ms: Start time of the segment in milliseconds.
:param end_ms: End time of the segment in milliseconds.
    :return: Audio samples of the segment as a float NumPy array normalized to [-1, 1], or None on error.
"""
#start_time = time.time()
try:
# Load the audio file
audio = AudioSegment.from_wav(file_path)
# Extract the segment
segment = audio[start_ms:end_ms]
# Get raw audio data
raw_data = segment.raw_data
# Get audio properties
frame_rate = segment.frame_rate
sample_width = segment.sample_width
channels = segment.channels
        # Convert the raw bytes to a NumPy array (assumes 16-bit PCM samples).
        audio_array = np.frombuffer(raw_data, dtype=np.int16)
        # If multi-channel, reshape and keep only the first channel.
        if channels > 1:
            audio_array = audio_array.reshape((-1, channels))
            audio_array = audio_array[..., 0]
'''
if frame_rate !=48000:
audio_array = audio_array/MAX_WAV_VALUE
audio_array = librosa.resample(audio_array, frame_rate, 48000)
audio_array = audio_array * MAX_WAV_VALUE
frame_rate = 48000
'''
#end_time = time.time()
#time_taken = end_time - start_time
#print(f"Successfully read segment from {start_ms}ms to {end_ms}ms in {time_taken:.4f} seconds")
return audio_array / MAX_WAV_VALUE#, frame_rate #, sample_width, channels
except Exception as e:
print(f"An error occurred: {e}")
return None#, None #, None, None
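# Illustrative usage of read_audio_segment (sketch; path and times are hypothetical):
# reads the slice between 1.0 s and 2.5 s and returns float samples in [-1, 1].
#   segment = read_audio_segment('/path/to/example.wav', start_ms=1000, end_ms=2500)
#   if segment is not None:
#       print(segment.shape)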
def resample(audio, sr_in, sr_out, target_len=None):
    """Resample audio to an explicit target length if given, otherwise by the ratio sr_out / sr_in."""
    #audio = audio / MAX_WAV_VALUE
    #audio = normalize(audio) * 0.95
    if target_len is not None:
        audio = scipy.signal.resample(audio, target_len)
        return audio
    resample_factor = sr_out / sr_in
    new_samples = int(len(audio) * resample_factor)
    audio = scipy.signal.resample(audio, new_samples)
    return audio
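# Illustrative usage of resample (sketch): either resample by a rate ratio or
# force an explicit output length (the second form is how __getitem__ restores
# the original length after downsampling). The array below is synthetic.
#   x = np.random.randn(48000)                # 1 s at 48 kHz
#   x_16k = resample(x, 48000, 16000)         # ratio-based: ~16000 samples
#   x_back = resample(x_16k, None, None, target_len=len(x))  # length-based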
def load_segment(full_path, target_sampling_rate=None, segment_size=None):
    """Load audio from full_path, optionally as a fixed-size random segment resampled to target_sampling_rate."""
if segment_size is not None:
dur,sampling_rate,len_data = get_wave_duration(full_path)
if sampling_rate is None: return None, None
if sampling_rate < 44100: return None, None
target_dur = segment_size / target_sampling_rate
        # File is shorter than the requested segment: load it all, resample, and zero-pad.
        if dur < target_dur:
data, sampling_rate = load_wav(full_path)
#print(f'data_read: {data.shape}, sampling_rate: {sampling_rate}')
if data is None: return None, None
if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
data = resample(data, sampling_rate, target_sampling_rate)
sampling_rate = target_sampling_rate
data = torch.FloatTensor(data)
data = data.unsqueeze(0)
data = torch.nn.functional.pad(data, (0, segment_size - data.size(1)), 'constant')
data = data.squeeze(0)
return data.numpy(), sampling_rate
        else:
            # File is long enough: read a random segment and crop/pad to segment_size.
            dur, sampling_rate, len_data = get_wave_duration(full_path)
if sampling_rate < 44100: return None, None
target_dur = segment_size / target_sampling_rate
target_len = int(target_dur * sampling_rate)
start_idx = random.randint(0, (len_data - target_len))
start_ms = start_idx / sampling_rate * 1000
end_ms = start_ms + target_dur * 1000
data = read_audio_segment(full_path, start_ms, end_ms)
#print(f'data_read: {data.shape}, sampling_rate: {sampling_rate}')
if data is None: return None, None
if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
data = resample(data, sampling_rate, target_sampling_rate)
sampling_rate = target_sampling_rate
if len(data) < segment_size:
data = torch.FloatTensor(data)
data = data.unsqueeze(0)
data = torch.nn.functional.pad(data, (0, segment_size - data.size(1)), 'constant')
data = data.squeeze(0)
data = data.numpy()
else:
start_idx = random.randint(0, (len(data) - segment_size))
data = data[start_idx:start_idx+segment_size]
#print(f'data_cut: {data.shape}')
return data, sampling_rate
    else:
        # No segment_size given: load the whole file and optionally resample it.
        dur, sampling_rate, len_data = get_wave_duration(full_path)
if sampling_rate is None: return None, None
if sampling_rate < 44100: return None, None
data, sampling_rate = load_wav(full_path)
if data is None: return None, None
if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
data = resample(data, sampling_rate, target_sampling_rate)
sampling_rate = target_sampling_rate
return data, sampling_rate
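# Illustrative usage of load_segment (sketch; path and sizes are assumptions):
# returns a fixed-length random crop (zero-padded if the file is too short),
# resampled to the target rate; files sampled below 44.1 kHz are rejected.
#   data, sr = load_segment('/path/to/example.wav', target_sampling_rate=48000,
#                           segment_size=8192)
#   if data is not None:
#       print(data.shape, sr)  # (8192,), 48000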
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
'''
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
'''
    global mel_basis, hann_window
    key = str(fmax) + '_' + str(y.device)
    if key not in mel_basis:
        #mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        # Cache the mel filterbank and Hann window per (fmax, device) pair.
        mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = torch.matmul(mel_basis[key], spec)
spec = spectral_normalize_torch(spec)
return spec
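# Illustrative usage of mel_spectrogram (sketch): the parameter values below are
# chosen for the example only, not prescribed by this repository.
#   y = torch.randn(1, 48000) * 0.1           # (batch, samples) in roughly [-1, 1]
#   m = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=48000,
#                       hop_size=256, win_size=1024, fmin=0, fmax=8000)
#   print(m.shape)                            # (1, 80, n_frames), log-compressed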
def get_dataset_filelist_org(a):
with open(a.input_training_file, 'r', encoding='utf-8') as fi:
training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
for x in fi.read().split('\n') if len(x) > 0]
with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
for x in fi.read().split('\n') if len(x) > 0]
return training_files, validation_files
def get_dataset_filelist(a):
with open(a.input_training_file, 'r', encoding='utf-8') as fi:
training_files = [x for x in fi.read().split('\n') if len(x) > 0]
with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
validation_files = [x for x in fi.read().split('\n') if len(x) > 0]
return training_files, validation_files
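# Illustrative usage of get_dataset_filelist (sketch): `a` is expected to be an
# argparse-style namespace; the attribute values below are hypothetical.
#   from argparse import Namespace
#   a = Namespace(input_training_file='train_list.txt',
#                 input_validation_file='valid_list.txt')
#   training_files, validation_files = get_dataset_filelist(a)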
class MelDataset(torch.utils.data.Dataset):
def __init__(self, training_files, segment_size, n_fft, num_mels,
hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
self.audio_files = training_files
random.seed(1234)
if shuffle:
random.shuffle(self.audio_files)
self.segment_size = segment_size
self.sampling_rate = sampling_rate
self.split = split
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.fmax_loss = fmax_loss
self.cached_wav = None
self.n_cache_reuse = n_cache_reuse
self._cache_ref_count = 0
self.device = device
self.fine_tuning = fine_tuning
self.base_mels_path = base_mels_path
self.supported_samples = [16000, 22050, 24000] #[4000, 8000, 16000, 22050, 24000, 32000]
#self.supported_samples = [4000, 8000] #, 16000, 22050, 24000, 32000]
def __getitem__(self, index):
filename = self.audio_files[index]
        while 1:
            #audio, sampling_rate = load_wav(filename)
            audio, sampling_rate = load_segment(filename, self.sampling_rate, self.segment_size)
            if audio is not None:
                break
            else:
                # Loading failed: fall back to a random earlier file and retry.
                filename = self.audio_files[random.randint(0, index)]
#audio, sampling_rate = load_wav(filename)
#audio, sampling_rate = load_segment(filename, self.sampling_rate, self.segment_size)
#audio = audio / MAX_WAV_VALUE
if not self.fine_tuning:
audio = normalize(audio) * 0.95
        # Simulate a low-resolution input: downsample to a random supported rate,
        # then upsample back to the original length.
        sr_out = random.choice(self.supported_samples)
        audio_down = resample(audio, self.sampling_rate, sr_out)
        target_len = len(audio) #/ downsample_factor
        audio_up = resample(audio_down, None, None, target_len)
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
audio_up = torch.FloatTensor(audio_up)
audio_up = audio_up.unsqueeze(0)
mel = mel_spectrogram(audio_up, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
center=False)
mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
center=False)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
    def __getitem__org(self, index):
        # Earlier __getitem__ implementation kept under a different name; not used by the DataLoader.
filename = self.audio_files[index]
if self._cache_ref_count == 0:
while 1:
audio, sampling_rate = load_wav(filename)
if audio is not None: break
else:
filename = self.audio_files[random.randint(0,index)]
audio, sampling_rate = load_wav(filename)
audio = audio / MAX_WAV_VALUE
if not self.fine_tuning:
audio = normalize(audio) * 0.95
#self.cached_wav = audio
if sampling_rate != self.sampling_rate:
resample_factor = self.sampling_rate / sampling_rate
new_samples = int(len(audio) * resample_factor)
audio = scipy.signal.resample(audio, new_samples)#.astype(np.int16)
#raise ValueError("{} SR doesn't match target {} SR".format(
# sampling_rate, self.sampling_rate))
downsample_factor = 16000 / self.sampling_rate
new_samples = int(len(audio) * downsample_factor)
audio_down = scipy.signal.resample(audio, new_samples)
new_samples = len(audio) #/ downsample_factor
audio_up = scipy.signal.resample(audio_down, new_samples)
#print(f'audio: {audio.shape}, audio_up: {audio_up.shape}')
#min_idx = min(len(audio), len(audio_up))
#audio = audio[:min_idx]
#audio_up = audio_up[:min_idx]
self.cached_wav = audio
self.cached_wav_up = audio_up
self._cache_ref_count = self.n_cache_reuse
else:
audio = self.cached_wav
audio_up = self.cached_wav_up
self._cache_ref_count -= 1
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
audio_up = torch.FloatTensor(audio_up)
audio_up = audio_up.unsqueeze(0)
if True:
if self.split:
if audio.size(1) >= self.segment_size:
max_audio_start = audio.size(1) - self.segment_size
audio_start = random.randint(0, max_audio_start)
audio = audio[:, audio_start:audio_start+self.segment_size]
audio_up = audio_up[:, audio_start:audio_start+self.segment_size]
else:
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
audio_up = torch.nn.functional.pad(audio_up, (0, self.segment_size - audio_up.size(1)), 'constant')
mel = mel_spectrogram(audio_up, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
center=False)
mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
center=False)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
def __len__(self):
return len(self.audio_files)
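# Illustrative end-to-end usage of MelDataset (sketch): the hyper-parameter values
# below are assumptions chosen for the example, not the repository's configuration.
#   train_files, _ = get_dataset_filelist(a)  # 'a' as in the sketch above
#   dataset = MelDataset(train_files, segment_size=8192, n_fft=1024, num_mels=80,
#                        hop_size=256, win_size=1024, sampling_rate=48000,
#                        fmin=0, fmax=8000, fmax_loss=None)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True,
#                                        num_workers=4, pin_memory=True)
#   mel, audio, filename, mel_loss = next(iter(loader))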