UniMus Project: OpenJMLA
Reimplementation of JMLA
Music tagging is the task of predicting the tags of music recordings. However, previous music tagging research primarily focuses on closed-set music tagging tasks, which cannot be generalized to new tags. In this work, we propose a zero-shot music tagging system modeled by a joint music and language attention (JMLA) model to address the open-set music tagging problem. The JMLA model consists of an audio encoder modeled by a pretrained masked autoencoder and a decoder modeled by Falcon7B. We introduce a perceiver resampler to convert arbitrary-length audio into fixed-length embeddings. We introduce dense attention connections between encoder and decoder layers to improve the information flow between them. We collect a large-scale music and description dataset from the internet. We propose to use ChatGPT to convert the raw descriptions into formalized and diverse descriptions to train the JMLA models. Our proposed JMLA system achieves a zero-shot audio tagging accuracy of 64.82% on the GTZAN dataset, outperforming previous zero-shot systems, and achieves results comparable to previous systems on the FMA and MagnaTagATune datasets.
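To illustrate how a resampler can turn arbitrary-length audio into a fixed number of embeddings, here is a minimal, dependency-free sketch. It is not the JMLA implementation: it uses a single attention head, no learned projections and no stacked layers, and the names `cross_attention_resample` and `random_vectors` are hypothetical. The only point it makes is that the output length is set by the number of latent queries, not by the number of input frames.

```python
import math
import random


def cross_attention_resample(frames, latents):
    """Map a variable-length list of frame vectors to len(latents) output vectors.

    Each latent acts as a query; the frames act as both keys and values.
    """
    dim = len(latents[0])
    scale = 1.0 / math.sqrt(dim)
    outputs = []
    for query in latents:
        # Scaled dot-product score between this latent and every frame.
        scores = [scale * sum(q * k for q, k in zip(query, frame)) for frame in frames]
        # Numerically stable softmax over the time axis.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # The weighted average of the frames gives one fixed-size output vector.
        outputs.append([
            sum(w * frame[d] for w, frame in zip(weights, frames))
            for d in range(dim)
        ])
    return outputs


def random_vectors(count, dim, rng):
    """Hypothetical helper: random stand-ins for frames or latent queries."""
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(count)]


rng = random.Random(0)
latents = random_vectors(4, 8, rng)       # 4 latent queries (learned in a real model)
short_clip = random_vectors(50, 8, rng)   # 50 audio frames
long_clip = random_vectors(300, 8, rng)   # 300 audio frames

short_out = cross_attention_resample(short_clip, latents)
long_out = cross_attention_resample(long_clip, latents)
print(len(short_out), len(long_out))  # prints: 4 4
```

Both clips come out as four 8-dimensional vectors, which is what lets a decoder consume audio of any duration with a fixed-size interface.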
Requirements
- conda create --name SpectPrompt python=3.9
- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
- pip install transformers datasets librosa einops_exts einops mmcls peft ipdb torchlibrosa
- pip install -U openmim
- mim install mmcv==1.7.1
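After installing, a quick check like the one below can confirm that the packages are visible to the active environment. This is only a sketch: the import names in `REQUIRED_MODULES` are assumed from the package names listed above (for example, `mim` for openmim), and `missing_modules` is a hypothetical helper, not part of OpenJMLA.

```python
import importlib.util

# Import names assumed for the packages listed above; "mim" is the module
# installed by openmim.
REQUIRED_MODULES = [
    "torch", "torchvision", "torchaudio", "transformers", "datasets",
    "librosa", "einops", "einops_exts", "mmcls", "peft", "ipdb",
    "torchlibrosa", "mim", "mmcv",
]


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

The check only looks the modules up without importing them, so it runs quickly and does not need a GPU.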
Quickstart
Below, we provide simple examples to show how to use OpenJMLA with 🤗 Transformers.
🤗 Transformers
To use OpenJMLA for inference, you only need a few lines of code, as demonstrated below.
from transformers import AutoModel, AutoTokenizer
import torch
import numpy as np
model = AutoModel.from_pretrained('UniMus/OpenJMLA', trust_remote_code=True)
device = model.device
# sample rate: 16k
music_path = '/path/to/music.wav'
# 1. get logmelspectrogram
# get the file wav_to_mel.py from https://github.com/taugastcn/SpectPrompt.git
from wav_to_mel import wav_to_mel
lms = wav_to_mel(music_path)
import os
from torch.nn.utils.rnn import pad_sequence
import random
# get the file transforms.py from https://github.com/taugastcn/SpectPrompt.git
from transforms import Normalize, SpecRandomCrop, SpecPadding, SpecRepeat
transforms = [ Normalize(-4.5, 4.5), SpecRandomCrop(target_len=2992), SpecPadding(target_len=2992), SpecRepeat() ]
lms = lms.numpy()
for trans in transforms:
    lms = trans(lms)
# 2. template of input
input_dic = dict()
input_dic['filenames'] = [music_path.split('/')[-1]]
input_dic['ans_crds'] = [0]
input_dic['audio_crds'] = [0]
input_dic['attention_mask'] = torch.tensor([[1, 1, 1, 1, 1]]).to(device)
input_dic['input_ids'] = torch.tensor([[1, 694, 5777, 683, 13]]).to(device)
input_dic['spectrogram'] = torch.from_numpy(lms).unsqueeze(dim=0).to(device)
# 3. generation
model.eval()
gen_ids = model.forward_test(input_dic)
gen_text = model.neck.tokenizer.batch_decode(gen_ids.clip(0))
# 4. Post-processing
# Given that the training data may contain biases, the generated texts might need some straightforward post-processing to ensure accuracy.
# In future versions, we will enhance the quality of the data.
# batch_decode returns a list of strings; take the first (and only) item
gen_text = gen_text[0].split('<s>')[-1].split('\n')[0].strip()
gen_text = gen_text.replace(' in Chinese','')
gen_text = gen_text.replace(' Chinese','')
print(gen_text)
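If you process several files, the post-processing step can be wrapped in a small helper. The sketch below mirrors the clean-up in the snippet above; `clean_caption` and `clean_captions` are hypothetical names, and the `raw` string is a made-up example rather than real model output.

```python
def clean_caption(decoded, drop_phrases=(" in Chinese", " Chinese")):
    """Apply the clean-up from the quickstart snippet to one decoded string."""
    # Keep only the text after the last <s> marker, and only its first line.
    text = decoded.split("<s>")[-1].split("\n")[0].strip()
    # Remove the unwanted phrases; " in Chinese" is listed first so it is
    # removed before the shorter " Chinese".
    for phrase in drop_phrases:
        text = text.replace(phrase, "")
    return text


def clean_captions(decoded_batch):
    """Clean every string in a batch, such as the list returned by batch_decode."""
    return [clean_caption(item) for item in decoded_batch]


# Made-up decoded output, for illustration only.
raw = ["<s>Describe the music.<s> Genre: pop in Chinese. Mood: calm.\nextra text"]
print(clean_captions(raw))  # prints: ['Genre: pop. Mood: calm.']
```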
Example
music:
https://www.youtube.com/watch?v=Q_yuO8UNGmY
caption:
Instruments: Vocals, piano, strings Genre: pop Theme: Heartbreak. Mood: Melancholy. Era: Contemporary. Tempo: Fast Best scene: A small, dimly lit bar. The melancholy mood of this song will complement the stage-inspired melody.
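To use a caption like this for tagging, it can help to split it into labeled fields. The sketch below assumes the caption follows the `Label: value` layout of the example above; the label list is taken from that single example, so other captions may need a different list, and `parse_caption` is a hypothetical helper.

```python
import re

# Field labels taken from the example caption; other captions may use
# different labels, so this list is an assumption.
FIELD_LABELS = ["Instruments", "Genre", "Theme", "Mood", "Era", "Tempo", "Best scene"]


def parse_caption(caption, labels=FIELD_LABELS):
    """Split a caption of the form 'Label: value Label: value ...' into a dict."""
    alternatives = "|".join(re.escape(label) for label in labels)
    pattern = re.compile(r"\b(" + alternatives + r"):\s*")
    # With one capturing group, split() returns
    # [text before first label, label1, value1, label2, value2, ...].
    parts = pattern.split(caption)
    fields = {}
    for label, value in zip(parts[1::2], parts[2::2]):
        fields[label] = value.strip().rstrip(".")
    return fields


caption = (
    "Instruments: Vocals, piano, strings Genre: pop Theme: Heartbreak. "
    "Mood: Melancholy. Era: Contemporary. Tempo: Fast Best scene: A small, "
    "dimly lit bar. The melancholy mood of this song will complement the "
    "stage-inspired melody."
)
fields = parse_caption(caption)
print(fields["Genre"], "|", fields["Mood"], "|", fields["Tempo"])  # prints: pop | Melancholy | Fast
```

Matching is case-sensitive and requires the trailing colon, so the lowercase word "mood" inside the scene description is not mistaken for a field label.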
Citation
If you find our paper and code useful in your research, please consider starring the repository and citing our paper.
@article{JMLA,
title={JOINT MUSIC AND LANGUAGE ATTENTION MODELS FOR ZERO-SHOT MUSIC TAGGING},
author={Xingjian Du and Zhesong Yu and Jiaju Lin and Bilei Zhu and Qiuqiang Kong},
journal={arXiv preprint arXiv:2310.10159},
year={2023}
}