---
language:
- en
license: cc-by-4.0
tags:
- music
- art
---

# Model Card for MusiLingo-short-v1

## Model Details

### Model Description

The model consists of a music encoder, `MERT-v1-330M`; a natural language decoder, `vicuna-7b-delta-v0`; and a linear projection layer between the two that maps MERT's audio features into the decoder's input embedding space.
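
The projection layer is what connects the two pretrained models: it maps each frame of MERT features to a vector of the decoder's embedding size, so the audio can sit alongside text token embeddings. Below is a minimal, dependency-free sketch of a linear projection `y = W x + b` applied frame by frame. The sizes 1024 (MERT-v1-330M hidden size) and 4096 (Vicuna-7B hidden size) are assumptions about this checkpoint, and `project_frames` / `projection_params` are illustrative helpers, not part of the MusiLingo code.

```python
# Illustrative sketch of a linear projection y = W x + b applied to every audio frame.
# Assumed sizes (not read from the checkpoint): MERT-v1-330M hidden size 1024,
# Vicuna-7B hidden size 4096.
MERT_DIM = 1024
VICUNA_DIM = 4096


def project_frames(frames, weight, bias):
    """Map each in_dim feature vector (one per audio frame) to out_dim.

    frames: list of frames, each a list of in_dim floats
    weight: list of out_dim rows, each a list of in_dim floats
    bias:   list of out_dim floats
    """
    return [
        [sum(w * x for w, x in zip(row, frame)) + b for row, b in zip(weight, bias)]
        for frame in frames
    ]


def projection_params(in_dim: int = MERT_DIM, out_dim: int = VICUNA_DIM) -> int:
    """Number of weights plus biases in a single linear layer."""
    return in_dim * out_dim + out_dim


if __name__ == "__main__":
    # Toy example: two 2-dim frames projected to 3 dims.
    frames = [[1.0, 2.0], [0.5, -1.0]]
    weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    bias = [0.0, 0.0, 1.0]
    print(project_frames(frames, weight, bias))  # [[1.0, 2.0, 4.0], [0.5, -1.0, 0.5]]
    print(projection_params())  # 4198400
```

In PyTorch the same idea would be `torch.nn.Linear(1024, 4096)` applied to a `(batch, frames, 1024)` tensor, producing `(batch, frames, 4096)`. Under the assumed sizes that is about 4.2M parameters, small next to the 7B-parameter decoder.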

This checkpoint of MusiLingo is trained on the short-answer subset of MusicInstruct (MI-short). It answers short instructions about raw music audio, such as questions about tempo, emotion, genre, or tags. The demo below uses the [MI](https://huggingface.co/datasets/m-a-p/Music-Instruct) dataset.

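Each instruction is paired with the audio through a text prompt containing an `<AudioHere>` placeholder, as built by the `answer` helper in Getting Started. The plain-Python sketch below illustrates that step. The template `"###Human: {} ###Assistant: "` is an assumption: the checkpoint's actual `prompt_template` is not shown in this card, although the `###` and `Assistant:` markers match the post-processing in `answer`. `build_prompt` and `split_at_audio` are hypothetical helpers. The split presumably marks where the projected audio embeddings are inserted between the embedded text halves, as in MiniGPT-4-style wrappers.

```python
# Hypothetical sketch of prompt construction for a short MI-style instruction.
AUDIO_PREFIX = "<Audio><AudioHere></Audio> "  # same prefix as in answer()
ASSUMED_TEMPLATE = "###Human: {} ###Assistant: "  # assumption, not read from the checkpoint


def build_prompt(instruction: str, template: str = ASSUMED_TEMPLATE) -> str:
    """Prefix the instruction with the audio placeholder and fill the template."""
    return template.format(AUDIO_PREFIX + instruction)


def split_at_audio(prompt: str) -> tuple[str, str]:
    """Split the prompt into the text before and after the audio placeholder.

    In a MiniGPT-4-style wrapper (assumed here), each half is embedded as text
    and the audio embeddings are placed between them.
    """
    before, after = prompt.split("<AudioHere>", 1)
    return before, after


if __name__ == "__main__":
    prompt = build_prompt("What is the tempo of this song?")
    print(prompt)
    print(split_at_audio(prompt))
```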
### Model Sources

- **Repository:** [GitHub repo](https://github.com/zihaod/MusiLingo)
- **Paper:** __[MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response](https://arxiv.org/abs/2309.08730)__

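## Getting Started

The script below runs the checkpoint on the short questions of the MI test split and prints each reference answer followed by the model's answer. It requires a CUDA GPU.
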
```python
# `CMIDataset` is not provided by transformers; it is assumed to come from the MusiLingo
# GitHub repository (https://github.com/zihaod/MusiLingo) and must be imported from there.
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence ends with any of the given stop token sequences."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = stops if stops is not None else []  # `encounters` is accepted but unused

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False


def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
           repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
    # `self` is the underlying MusiLingo model, passed explicitly; `max_length` is unused.
    audio = samples["audio"].cuda()
    audio_embeds, atts_audio = self.encode_audio(audio)

    if 'instruction_input' in samples:  # instruction dataset: wrap each question in the prompt
        instruction_prompt = []
        for instruction in samples['instruction_input']:
            prompt = '<Audio><AudioHere></Audio> ' + instruction
            instruction_prompt.append(self.prompt_template.format(prompt))
        audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)

    self.llama_tokenizer.padding_side = "right"
    batch_size = audio_embeds.shape[0]
    # Prepend the BOS token embedding to the (audio + prompt) embeddings.
    bos = torch.ones([batch_size, 1], dtype=torch.long,
                     device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
    bos_embeds = self.llama_model.model.embed_tokens(bos)
    atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    attention_mask = torch.cat([atts_bos, atts_audio], dim=1)

    outputs = self.llama_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )

    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # if there is a start token <s> at the beginning, remove it
        output_token = output_token[1:]
    output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text


processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
ds = CMIDataset(processor, 'path/to/MI_dataset', 'test', question_type='short')
dl = DataLoader(
    ds,
    batch_size=1,
    num_workers=0,
    pin_memory=True,
    shuffle=False,
    drop_last=True,
    collate_fn=ds.collater,
)

# '###' can be tokenized as [835] or as [2277, 29937]; stop on either.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                                      torch.tensor([2277, 29937]).cuda()])])

# MusiLingo ships custom modeling code, so trust_remote_code=True is required.
model_short = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)
model_short = model_short.cuda().eval()  # answer() puts inputs on CUDA, so the model must be there too

with torch.no_grad():
    for idx, sample in tqdm(enumerate(dl)):
        # `model_short.model` is assumed to be the MusiLingo module that answer() expects.
        ans = answer(model_short.model, sample, stopping, length_penalty=100, temperature=0.1)
        txt = sample['text_input'][0]  # reference answer
        print(txt)
        print(ans)
```
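
To make the stopping and clean-up steps in `answer` easier to follow, here is a dependency-free sketch that applies the same logic to plain Python lists and strings. It is an illustration, not the model's code. Token IDs 0 and 1 come from the comments in `answer` (`<unk>` and `<s>`), and `[835]` / `[2277, 29937]` are the stop sequences passed to `StoppingCriteriaSub`, presumably the two ways the LLaMA tokenizer can encode `###`. The helper names are hypothetical.

```python
# Dependency-free sketch of the stop check and answer clean-up performed in answer().
STOP_SEQUENCES = [[835], [2277, 29937]]  # same IDs as passed to StoppingCriteriaSub
UNK_ID, BOS_ID = 0, 1  # <unk> and <s>, per the comments in answer()


def ends_with_stop(generated, stops=STOP_SEQUENCES):
    """True if the generated token list ends with any of the stop sequences."""
    return any(len(generated) >= len(stop) and generated[-len(stop):] == stop for stop in stops)


def strip_leading_special(tokens, unk_id=UNK_ID, bos_id=BOS_ID):
    """Drop a leading <unk>, then a leading <s>, as answer() does."""
    if tokens and tokens[0] == unk_id:
        tokens = tokens[1:]
    if tokens and tokens[0] == bos_id:
        tokens = tokens[1:]
    return tokens


def clean_answer(text):
    """Cut at the first '###' stop marker and keep only the text after 'Assistant:'."""
    text = text.split("###")[0]
    return text.split("Assistant:")[-1].strip()


if __name__ == "__main__":
    print(ends_with_stop([42, 2277, 29937]))  # True
    print(strip_leading_special([0, 1, 42]))  # [42]
    print(clean_answer("Assistant: The tempo is moderate.###Human"))  # The tempo is moderate.
```

Checking only the end of the sequence keeps the stop test cheap at every generation step, and listing both encodings of `###` avoids missing the marker when the tokenizer splits it.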

## Citing This Work

If you find this work useful for your research, please consider citing it using the following BibTeX entry:

```bibtex
@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}
```