### # Author: Kai Li # Date: 2021-06-09 16:43:09 # LastEditors: Please set LastEditors # LastEditTime: 2024-01-24 00:00:52 ### import torch from torch.nn.modules.loss import _Loss def freq_MAE(output, target): loss = 0. eps = torch.finfo(torch.float32).eps all_win = [32, 64, 128, 256, 512, 1024, 2048] for win in all_win: est_spec = torch.stft(output.view(-1, output.shape[-1]), n_fft=win, hop_length=win//2, window=torch.hann_window(win).to(output.device).float(), return_complex=True) target_spec = torch.stft(target.view(-1, target.shape[-1]), n_fft=win, hop_length=win//2, window=torch.hann_window(win).to(target.device).float(), return_complex=True) loss = loss + (est_spec.abs() - target_spec.abs()).abs().mean() / (target_spec.abs().mean() + eps) return loss / len(all_win) class MultiFrequencyDisLoss(_Loss): def __init__(self, eps=1e-8): super(MultiFrequencyDisLoss, self).__init__() def forward(self, target_outputs, est_outputs): D_real = 0 D_fake = 0 for i in range(len(target_outputs)): D_real = D_real + (target_outputs[i] - 1).pow(2).mean() / len(target_outputs) D_fake = D_fake + (est_outputs[i]).pow(2).mean() / len(est_outputs) return D_real + D_fake class MultiFrequencyGenLoss(_Loss): def __init__(self, eps=1e-8): super(MultiFrequencyGenLoss, self).__init__() self.eps = eps def forward(self, est_outputs, est_feature_maps, targets_feature_maps, output, ori_data): G_fake = 0 feature_matching = 0 eps = self.eps for i in range(len(est_outputs)): G_fake = G_fake + (est_outputs[i] - 1).pow(2).mean() / len(est_outputs) for j in range(len(est_feature_maps[i])): feature_matching = feature_matching + (est_feature_maps[i][j] - targets_feature_maps[i][j].detach()).abs().mean() / (targets_feature_maps[i][j].detach().abs().mean() + eps) feature_matching = feature_matching / (len(est_outputs) * len(est_feature_maps[0])) freq_loss = freq_MAE(output, ori_data.unsqueeze(1)) total_loss = freq_loss + G_fake + feature_matching return total_loss