OpenSound's picture
Upload 211 files
9d3cb0a verified
import os
import numpy as np
import torch
from .. import AudioSignal
def stoi(
    estimates: AudioSignal,
    references: AudioSignal,
    extended: bool = False,
) -> torch.Tensor:
    """Short term objective intelligibility

    Computes the STOI (See [1][2]) of a denoised signal compared to a clean
    signal. The output is expected to have a monotonic relation with the
    subjective speech-intelligibility, where a higher score denotes better
    speech intelligibility. Uses pystoi under the hood.

    Both inputs are cloned and downmixed to mono before scoring, and each
    batch item is scored separately at the sample rate of ``references``.
    No resampling is performed here.

    Parameters
    ----------
    estimates : AudioSignal
        Denoised speech
    references : AudioSignal
        Clean original speech
    extended : bool, optional
        Whether to use the extended STOI described in [3], by default False

    Returns
    -------
    Tensor[float]
        Short time objective intelligibility measure between clean and
        denoised speech, one value per batch item

    References
    ----------
    1.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
        Objective Intelligibility Measure for Time-Frequency Weighted Noisy
        Speech', ICASSP 2010, Texas, Dallas.
    2.  C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
        Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
        IEEE Transactions on Audio, Speech, and Language Processing, 2011.
    3.  Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
        Intelligibility of Speech Masked by Modulated Noise Maskers',
        IEEE Transactions on Audio, Speech and Language Processing, 2016.
    """
    # Imported lazily so pystoi is only required when this metric is used.
    import pystoi

    # Score mono copies; clone() comes first, presumably so that to_mono()
    # does not modify the caller's signals -- TODO confirm.
    estimates = estimates.clone().to_mono()
    references = references.clone().to_mono()

    sample_rate = references.sample_rate

    # Argument order matters: the clean reference is passed first, the
    # denoised estimate second.
    stois = [
        pystoi.stoi(
            references.audio_data[i, 0].detach().cpu().numpy(),
            estimates.audio_data[i, 0].detach().cpu().numpy(),
            sample_rate,
            extended=extended,
        )
        for i in range(estimates.batch_size)
    ]
    return torch.from_numpy(np.array(stois))
def pesq(
    estimates: AudioSignal,
    references: AudioSignal,
    mode: str = "wb",
    target_sr: float = 16000,
):
    """Perceptual Evaluation of Speech Quality (PESQ) score.

    Both signals are cloned, downmixed to mono and resampled to
    ``target_sr``; each batch item is then scored separately with the
    ``pesq`` package.

    Parameters
    ----------
    estimates : AudioSignal
        Degraded AudioSignal
    references : AudioSignal
        Reference AudioSignal
    mode : str, optional
        'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
    target_sr : float, optional
        Target sample rate, by default 16000

    Returns
    -------
    Tensor[float]
        PESQ score: P.862.2 Prediction (MOS-LQO), one value per batch item
    """
    # Imported lazily so the pesq package is only required when this
    # metric is used.
    from pesq import pesq as pesq_fn

    degraded = estimates.clone().to_mono().resample(target_sr)
    clean = references.clone().to_mono().resample(target_sr)

    def first_channel(signal, index):
        # First channel of one batch item as a NumPy array on the CPU.
        return signal.audio_data[index, 0].detach().cpu().numpy()

    # The scorer is called as (sample rate, reference, degraded, mode).
    scores = [
        pesq_fn(
            degraded.sample_rate,
            first_channel(clean, index),
            first_channel(degraded, index),
            mode,
        )
        for index in range(degraded.batch_size)
    ]
    return torch.from_numpy(np.array(scores))
def visqol(
    estimates: AudioSignal,
    references: AudioSignal,
    mode: str = "audio",
):  # pragma: no cover
    """ViSQOL score.

    Builds a ViSQOL configuration for the requested mode, resamples mono
    copies of both signals to the sample rate that mode uses (48000 for
    'audio', 16000 for 'speech'), and scores each batch item separately.

    Parameters
    ----------
    estimates : AudioSignal
        Degraded AudioSignal
    references : AudioSignal
        Reference AudioSignal
    mode : str, optional
        'audio' or 'speech', by default 'audio'

    Returns
    -------
    Tensor[float]
        ViSQOL score (MOS-LQO), one value per batch item

    Raises
    ------
    ValueError
        If ``mode`` is neither 'audio' nor 'speech'.
    """
    from visqol import visqol_lib_py
    from visqol.pb2 import visqol_config_pb2

    # Not referenced below; presumably imported for its protobuf
    # registration side effects -- TODO confirm before removing.
    from visqol.pb2 import similarity_result_pb2

    config = visqol_config_pb2.VisqolConfig()

    if mode not in ("audio", "speech"):
        raise ValueError(f"Unrecognized mode: {mode}")

    # Each mode has its own sample rate, scoring flag and model file.
    speech_mode = mode == "speech"
    target_sr = 16000 if speech_mode else 48000
    model_file = (
        "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
        if speech_mode
        else "libsvm_nu_svr_model.txt"
    )

    config.options.use_speech_scoring = speech_mode
    config.audio.sample_rate = target_sr

    # Model files are looked up in the "model" directory next to the
    # visqol_lib_py module.
    package_dir = os.path.dirname(visqol_lib_py.__file__)
    config.options.svr_model_path = os.path.join(package_dir, "model", model_file)

    api = visqol_lib_py.VisqolApi()
    api.Create(config)

    degraded = estimates.clone().to_mono().resample(target_sr)
    clean = references.clone().to_mono().resample(target_sr)

    scores = []
    for index in range(degraded.batch_size):
        # Reference first, degraded second; both cast to Python float
        # (float64) arrays before being handed to the API.
        clean_wave = clean.audio_data[index, 0].detach().cpu().numpy().astype(float)
        degraded_wave = (
            degraded.audio_data[index, 0].detach().cpu().numpy().astype(float)
        )
        result = api.Measure(clean_wave, degraded_wave)
        scores.append(result.moslqo)

    return torch.from_numpy(np.array(scores))