# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import librosa
import torch
import numpy as np


def extract_mstft(
    audio_ref,
    audio_deg,
    fs=None,
    mid_freq=None,
    high_freq=None,
    method="cut",
    version="pwg",
):
    """Compute the Multi-Scale STFT Distance (mstft) between the predicted and the ground truth audio.

    audio_ref: path to the ground truth audio.
    audio_deg: path to the predicted audio.
    fs: sampling rate.
    mid_freq: division frequency for the mid frequency band.
    high_freq: division frequency for the high frequency band.
    method: "dtw" will use the dtw algorithm to align the lengths of the ground truth and predicted audio.
            "cut" will truncate both audios to the length of the shorter one.
    version: "pwg" will use the computational version provided by ParallelWaveGAN.
             "encodec" will use the computational version provided by Encodec.
    """
    # Load audio
    if fs is not None:
        audio_ref, _ = librosa.load(audio_ref, sr=fs)
        audio_deg, _ = librosa.load(audio_deg, sr=fs)
    else:
        audio_ref, fs = librosa.load(audio_ref)
        audio_deg, fs = librosa.load(audio_deg)

    # Automatically choose mid_freq and high_freq if they are not given
    if mid_freq is None:
        mid_freq = fs // 6
    if high_freq is None:
        high_freq = fs // 3
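    # Note: these defaults split the spectrum at one third and two thirds of the
    # Nyquist frequency (fs / 2), giving low / mid / high bands of equal width.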
    # Audio length alignment
    if len(audio_ref) != len(audio_deg):
        if method == "cut":
            length = min(len(audio_ref), len(audio_deg))
            audio_ref = audio_ref[:length]
            audio_deg = audio_deg[:length]
        elif method == "dtw":
            _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
            audio_ref_new = []
            audio_deg_new = []
            for i in range(wp.shape[0]):
                ref_index = wp[i][0]
                deg_index = wp[i][1]
                audio_ref_new.append(audio_ref[ref_index])
                audio_deg_new.append(audio_deg[deg_index])
            audio_ref = np.array(audio_ref_new)
            audio_deg = np.array(audio_deg_new)
            assert len(audio_ref) == len(audio_deg)

    # Define loss function
    l1Loss = torch.nn.L1Loss(reduction="mean")
    l2Loss = torch.nn.MSELoss(reduction="mean")

    # Compute distance
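    # Two computational styles are supported:
    #   - "encodec": linear-magnitude L1 + L2 distance, averaged over six STFT
    #     window sizes (2**5 .. 2**10, hop = window / 4), following Encodec.
    #   - "pwg": spectral convergence + log-magnitude L1, averaged over three
    #     STFT resolutions, following ParallelWaveGAN's multi-resolution STFT loss.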
if version == "encodec":
n_fft = 1024
mstft = 0
mstft_low = 0
mstft_mid = 0
mstft_high = 0
freq_resolution = fs / n_fft
mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))
for i in range(5, 11):
hop_length = 2**i // 4
win_length = 2**i
spec_ref = librosa.stft(
y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
)
spec_deg = librosa.stft(
y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
)
mag_ref = np.abs(spec_ref)
mag_deg = np.abs(spec_deg)
mag_ref = torch.from_numpy(mag_ref)
mag_deg = torch.from_numpy(mag_deg)
mstft += l1Loss(mag_ref, mag_deg) + l2Loss(mag_ref, mag_deg)
mag_ref_low = mag_ref[:mid_freq_index, :]
mag_deg_low = mag_deg[:mid_freq_index, :]
mstft_low += l1Loss(mag_ref_low, mag_deg_low) + l2Loss(
mag_ref_low, mag_deg_low
)
mag_ref_mid = mag_ref[mid_freq_index:high_freq_index, :]
mag_deg_mid = mag_deg[mid_freq_index:high_freq_index, :]
mstft_mid += l1Loss(mag_ref_mid, mag_deg_mid) + l2Loss(
mag_ref_mid, mag_deg_mid
)
mag_ref_high = mag_ref[high_freq_index:, :]
mag_deg_high = mag_deg[high_freq_index:, :]
mstft_high += l1Loss(mag_ref_high, mag_deg_high) + l2Loss(
mag_ref_high, mag_deg_high
)
mstft /= 6
mstft_low /= 6
mstft_mid /= 6
mstft_high /= 6
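        # The per-band accumulators above are averaged too, but only the
        # full-band distance is returned by this version.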
        return mstft
elif version == "pwg":
fft_sizes = [1024, 2048, 512]
hop_sizes = [120, 240, 50]
win_sizes = [600, 1200, 240]
audio_ref = torch.from_numpy(audio_ref)
audio_deg = torch.from_numpy(audio_deg)
mstft_sc = 0
mstft_sc_low = 0
mstft_sc_mid = 0
mstft_sc_high = 0
mstft_mag = 0
mstft_mag_low = 0
mstft_mag_mid = 0
mstft_mag_high = 0
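        # For each of the three STFT resolutions, accumulate the spectral
        # convergence loss (Frobenius-norm ratio) and the log-magnitude L1 loss,
        # both on the full spectrum and on the low / mid / high bands.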
        for n_fft, hop_length, win_length in zip(fft_sizes, hop_sizes, win_sizes):
            spec_ref = torch.stft(
                audio_ref, n_fft, hop_length, win_length, return_complex=False
            )
            spec_deg = torch.stft(
                audio_deg, n_fft, hop_length, win_length, return_complex=False
            )

            real_ref = spec_ref[..., 0]
            imag_ref = spec_ref[..., 1]
            real_deg = spec_deg[..., 0]
            imag_deg = spec_deg[..., 1]

            mag_ref = torch.sqrt(
                torch.clamp(real_ref**2 + imag_ref**2, min=1e-7)
            ).transpose(1, 0)
            mag_deg = torch.sqrt(
                torch.clamp(real_deg**2 + imag_deg**2, min=1e-7)
            ).transpose(1, 0)

            sc_loss = torch.norm(mag_ref - mag_deg, p="fro") / torch.norm(
                mag_ref, p="fro"
            )
            mag_loss = l1Loss(torch.log(mag_ref), torch.log(mag_deg))
            mstft_sc += sc_loss
            mstft_mag += mag_loss

            freq_resolution = fs / n_fft
            mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
            high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))

            mag_ref_low = mag_ref[:, :mid_freq_index]
            mag_deg_low = mag_deg[:, :mid_freq_index]
            sc_loss_low = torch.norm(mag_ref_low - mag_deg_low, p="fro") / torch.norm(
                mag_ref_low, p="fro"
            )
            mag_loss_low = l1Loss(torch.log(mag_ref_low), torch.log(mag_deg_low))
            mstft_sc_low += sc_loss_low
            mstft_mag_low += mag_loss_low

            mag_ref_mid = mag_ref[:, mid_freq_index:high_freq_index]
            mag_deg_mid = mag_deg[:, mid_freq_index:high_freq_index]
            sc_loss_mid = torch.norm(mag_ref_mid - mag_deg_mid, p="fro") / torch.norm(
                mag_ref_mid, p="fro"
            )
            mag_loss_mid = l1Loss(torch.log(mag_ref_mid), torch.log(mag_deg_mid))
            mstft_sc_mid += sc_loss_mid
            mstft_mag_mid += mag_loss_mid

            mag_ref_high = mag_ref[:, high_freq_index:]
            mag_deg_high = mag_deg[:, high_freq_index:]
            sc_loss_high = torch.norm(
                mag_ref_high - mag_deg_high, p="fro"
            ) / torch.norm(mag_ref_high, p="fro")
            mag_loss_high = l1Loss(torch.log(mag_ref_high), torch.log(mag_deg_high))
            mstft_sc_high += sc_loss_high
            mstft_mag_high += mag_loss_high

        # Normalize distances
        mstft_sc /= len(fft_sizes)
        mstft_sc_low /= len(fft_sizes)
        mstft_sc_mid /= len(fft_sizes)
        mstft_sc_high /= len(fft_sizes)

        mstft_mag /= len(fft_sizes)
        mstft_mag_low /= len(fft_sizes)
        mstft_mag_mid /= len(fft_sizes)
        mstft_mag_high /= len(fft_sizes)

        # return (
        #     mstft_sc.numpy().tolist(),
        #     mstft_sc_low.numpy().tolist(),
        #     mstft_sc_mid.numpy().tolist(),
        #     mstft_sc_high.numpy().tolist(),
        #     mstft_mag.numpy().tolist(),
        #     mstft_mag_low.numpy().tolist(),
        #     mstft_mag_mid.numpy().tolist(),
        #     mstft_mag_high.numpy().tolist(),
        # )
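        # The value returned below is the sum of the averaged spectral convergence
        # loss and the averaged full-band log-magnitude L1 loss (both are scalars
        # after .numpy().tolist() on the 0-dim tensors).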
        return mstft_sc.numpy().tolist() + mstft_mag.numpy().tolist()
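

# A minimal usage sketch (not part of the original module). The file paths below
# are hypothetical placeholders; any pair of ground-truth / generated WAV files
# will do. It assumes librosa, torch, and numpy are installed.
if __name__ == "__main__":
    ref_path = "ref.wav"  # hypothetical ground-truth audio
    deg_path = "pred.wav"  # hypothetical predicted audio

    # ParallelWaveGAN-style distance (default) and Encodec-style distance at 16 kHz.
    pwg_distance = extract_mstft(ref_path, deg_path, fs=16000, version="pwg")
    encodec_distance = extract_mstft(ref_path, deg_path, fs=16000, version="encodec")

    print("pwg mstft distance:", pwg_distance)
    print("encodec mstft distance:", float(encodec_distance))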