'''
@Project :Waveformer-main
@File :CLAPSep.py
@IDE :PyCharm
@Author :Aisaka/Hao Ma @SDU
@Date :2024/2/28 1:12 PM
'''
|
|
import torch
from torch import nn
import torchaudio
import laion_clap
from .CLAPSep_decoder import HTSAT_Decoder
import copy
import loralib as lora
from torchlibrosa import ISTFT, STFT
from torchlibrosa.stft import magphase
import librosa
|
|
def set_module(model, submodule_key, module): |
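    """Replace the submodule of `model` addressed by the dotted path `submodule_key` with `module`."""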
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)
|
|
def process_model(model, rank): |
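    """Inject LoRA adapters: replace every nn.Linear inside the encoder's WindowAttention blocks with a
    loralib Linear of rank `rank`, copying over the pretrained weight (and bias, if present).
    The model is modified in place and also returned."""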
    for n, module in model.named_modules():
        if 'WindowAttention' in str(type(module)):
            for n_, layer in module.named_modules():
                if isinstance(layer, torch.nn.Linear):
                    # nn.Linear always exposes a `bias` attribute, so test for None rather than hasattr.
                    lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                             bias=layer.bias is not None, merge_weights=True)
                    lora_layer.weight = layer.weight
                    if layer.bias is not None:
                        lora_layer.bias = layer.bias
                    set_module(model, n + '.' + n_, lora_layer)
    return model
|
|
class CLAPSep(nn.Module): |
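    """CLAP-queried target sound separation: an (optionally LoRA-adapted) copy of the CLAP audio encoder
    extracts multi-scale features from the mixture, and an HTSAT_Decoder predicts a spectrogram mask
    conditioned on the concatenated positive/negative query embeddings."""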
|
    def __init__(self, model_config, CLAP_path, use_lora=True, rank=16, nfft=1024):
        super().__init__()
        # The separator operates on 32 kHz audio, while the CLAP audio encoder expects 48 kHz input.
        self.resampler = torchaudio.transforms.Resample(32000, 48000)
        self.clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
        self.clap_model.load_ckpt(CLAP_path)
        for p in self.clap_model.parameters():
            p.requires_grad = False
        # Work on a separate copy of the audio branch so LoRA adapters do not touch the frozen CLAP model.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
        self.decoder_model = HTSAT_Decoder(**model_config)
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
        self.features = self.install_forward_hooks()
|
    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        """Apply the predicted magnitude mask to the mixture spectrogram and reuse the mixture phase for the iSTFT."""
        mag_y = torch.nn.functional.relu_(mag_x * mask)
        cos_y = cos_x
        sin_y = sin_x
        pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
        return pred
|
    def inference_from_data(self, mixed, embed_pos, embed_neg):
        """Separate the target source from the mixture waveform `mixed` (batch, samples) given the
        positive and negative CLAP query embeddings."""
        self.eval()
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        self.features.append(mag)
        with torch.no_grad():
            embed = torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)
            # Run the CLAP audio branch; the forward hooks fill self.features with encoder features.
            self.audio_branch({"waveform": self.resampler(mixed)})
            mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
            pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
        del self.features[:]
        return pred
|
    def install_forward_hooks(self):
        """Register forward hooks on the CLAP audio branch that collect the multi-scale encoder features
        consumed by the decoder. Returns the (shared) list the hooks append to."""
        features = []

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            features.append(output[0])

        def spectrogram_padding(_, __, out):
            # Pad the spectrogram along the time axis to a fixed 1024 frames.
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features
|
|
if __name__ == '__main__': |
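    # Example: build the model, load a trained checkpoint, and separate one mixture
    # using a positive/negative text prompt pair.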
    model_config = {"lan_embed_dim": 1024,
                    "depths": [1, 1, 1, 1],
                    "embed_dim": 128,
                    "encoder_embed_dim": 128,
                    "phase": False,
                    "spec_factor": 8,
                    "d_attn": 640,
                    "n_masker_layer": 3,
                    "conv": False}
    CLAP_path = "./music_audioset_epoch_15_esc_90.14.pt"

    model = CLAPSep(model_config, CLAP_path)
    ckpt = torch.load('best_model.ckpt', map_location='cpu')
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    audio, fs = librosa.load("./510_25.221254348754883_mixture.wav", sr=32000)
    # `inference_from_data` expects CLAP text embeddings rather than raw prompts, so encode the
    # (empty) positive prompt and the negative prompt with the frozen CLAP text encoder first.
    with torch.no_grad():
        embed_pos = model.clap_model.get_text_embedding([''], use_tensor=True)
        embed_neg = model.clap_model.get_text_embedding(['A vehicle engine revving then powering down.'], use_tensor=True)
    pred = model.inference_from_data(torch.tensor(audio).unsqueeze(0), embed_pos, embed_neg)
    import soundfile as sf
    sf.write('./pred.wav', pred.squeeze().numpy(), 32000)