#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :Waveformer-main
@File :CLAPSep.py
@IDE :PyCharm
@Author :Aisaka/Hao Ma @SDU
@Date :2024/2/28 1:12 PM
'''
import torch
from torch import nn
import torchaudio
import laion_clap
from .CLAPSep_decoder import HTSAT_Decoder
import copy
import loralib as lora
from torchlibrosa import ISTFT, STFT
from torchlibrosa.stft import magphase
import librosa


def set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)


def process_model(model, rank):
    # Swap every nn.Linear inside the Swin WindowAttention blocks for a LoRA
    # linear layer of rank `rank`, copying the pretrained weights over.
    for n, module in model.named_modules():
        if 'WindowAttention' in str(type(module)):
            for n_, layer in module.named_modules():
                if isinstance(layer, torch.nn.Linear):
                    # nn.Linear always has a `bias` attribute (possibly None),
                    # so check the value rather than hasattr.
                    lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                             bias=layer.bias is not None, merge_weights=True)
                    lora_layer.weight = layer.weight
                    if layer.bias is not None:
                        lora_layer.bias = layer.bias
                    set_module(model, n + '.' + n_, lora_layer)
    return model
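

# In a typical LoRA fine-tuning setup only the injected adapters are updated.
# A minimal training-side sketch, assuming the standard loralib API:
#
#     encoder = process_model(encoder, rank=16)
#     lora.mark_only_lora_as_trainable(encoder)  # freezes all non-LoRA weights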


class CLAPSep(nn.Module):
    def __init__(self, model_config, CLAP_path, use_lora=True, rank=16, nfft=1024):
        super().__init__()
        # CLAP's HTSAT encoder expects 48 kHz audio, while the separator
        # operates on 32 kHz mixtures, so inputs are resampled on the way in.
        self.resampler = torchaudio.transforms.Resample(32000, 48000)
        self.clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
        self.clap_model.load_ckpt(CLAP_path)
        for p in self.clap_model.parameters():
            p.requires_grad = False
        # A separate copy of the HTSAT audio branch serves as the encoder
        # (adapted via LoRA when use_lora=True); the original CLAP stays
        # frozen and is used only for prompt embedding.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
        self.decoder_model = HTSAT_Decoder(**model_config)
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
        self.features = self.install_forward_hooks()
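
    # With nfft=1024 and hop_length=320 at the 32 kHz input rate, each STFT
    # frame covers 32 ms with a 10 ms hop and yields 513 frequency bins.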

    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        # Apply the predicted magnitude mask and resynthesize the waveform,
        # reusing the mixture's phase (cos_x, sin_x) unchanged.
        mag_y = torch.nn.functional.relu_(mag_x * mask)
        cos_y = cos_x
        sin_y = sin_x
        pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
        return pred
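
    # Sanity sketch (hypothetical, not part of the pipeline): with an
    # all-ones mask, wav_reconstruct should return the mixture up to STFT
    # windowing error:
    #
    #     real, imag = self.stft(mixed)
    #     mag, cos, sin = magphase(real, imag)
    #     recon = self.wav_reconstruct(torch.ones_like(mag), mag, cos, sin,
    #                                  length=mixed.size(-1))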

    def inference_from_data(self, mixed, embed_pos, embed_neg):
        # `embed_pos` / `embed_neg` are CLAP embeddings of the positive and
        # negative prompts; their normalized concatenation conditions the
        # decoder.
        self.eval()
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        self.features.append(mag)
        with torch.no_grad():
            embed = torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)
            # The encoder forward pass fills self.features through the hooks
            # installed in install_forward_hooks.
            self.audio_branch({"waveform": self.resampler(mixed)})
            mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
            pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
        del self.features[:]
        return pred

    def install_forward_hooks(self):
        # Register hooks that capture intermediate encoder activations, which
        # the decoder consumes as skip connections.
        features = []

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            features.append(output[0])

        def spectrogram_padding(_, __, out):
            # Pad the spectrogram along the time axis to a fixed 1024 frames,
            # the input size HTSAT expects.
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features
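
    # Note: the hooks append into the shared `features` list on every encoder
    # forward pass, so the list must be emptied after each call, as
    # inference_from_data does with `del self.features[:]`; otherwise
    # activations would accumulate across inferences.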


if __name__ == '__main__':
    import soundfile as sf

    model_config = {"lan_embed_dim": 1024,
                    "depths": [1, 1, 1, 1],
                    "embed_dim": 128,
                    "encoder_embed_dim": 128,
                    "phase": False,
                    "spec_factor": 8,
                    "d_attn": 640,
                    "n_masker_layer": 3,
                    "conv": False}
    CLAP_path = "./music_audioset_epoch_15_esc_90.14.pt"
    model = CLAPSep(model_config, CLAP_path)
    ckpt = torch.load('best_model.ckpt', map_location='cpu')
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    audio, fs = librosa.load("./510_25.221254348754883_mixture.wav", sr=32000)
    # inference_from_data expects CLAP embeddings rather than raw prompt
    # strings, so the prompts are embedded first; here the positive prompt is
    # left empty and the negative prompt alone drives the separation.
    embed_pos = model.clap_model.get_text_embedding([''], use_tensor=True)
    embed_neg = model.clap_model.get_text_embedding(
        ['A vehicle engine revving then powering down.'], use_tensor=True)
    pred = model.inference_from_data(torch.tensor(audio).unsqueeze(0), embed_pos, embed_neg)
    sf.write('./pred.wav', pred.squeeze().numpy(), 32000)
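
    # Hypothetical variant: prompt embeddings can also come from reference
    # audio through CLAP's audio tower (the file path is illustrative):
    #
    #     ref, _ = librosa.load('./reference.wav', sr=48000)
    #     embed_pos = model.clap_model.get_audio_embedding_from_data(
    #         x=torch.tensor(ref).unsqueeze(0), use_tensor=True)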