jamimulgrave committed
Commit c961996 · 1 Parent(s): 7b6887d

Upload 10 files

code/LyricsCommentData.py ADDED
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+import os
+
+
+@dataclass
+class LyricsCommentData(object):
+    music4all_id: str
+    songmeanings_id: str
+    lyrics: str
+    comment: str
+
+    def get_audio_path(self):  # get audio path from id
+        self.audio_path = os.path.join("Music4All/music4all/audios",
+                                       self.music4all_id + '.mp3'
+                                       )
+        return self.audio_path
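For reference, a minimal usage sketch of the dataclass above (the id values are hypothetical placeholders; get_audio_path only joins the path and does not check that the mp3 exists):

    from LyricsCommentData import LyricsCommentData

    record = LyricsCommentData(music4all_id="0a1b2c3d",
                               songmeanings_id="107822",
                               lyrics="example lyrics ...",
                               comment="example listener comment ...")
    print(record.get_audio_path())  # Music4All/music4all/audios/0a1b2c3d.mp3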
code/attention_modules.py ADDED
@@ -0,0 +1,267 @@
+# coding: utf-8
+# Code adopted from https://github.com/huggingface/pytorch-pretrained-BERT
+
+import math
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+# Gelu
+def gelu(x):
+    """Implementation of the gelu activation function.
+    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+    Also see https://arxiv.org/abs/1606.08415
+    """
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+# LayerNorm
+try:
+    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
+except ImportError:
+    # print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
+    class BertLayerNorm(nn.Module):
+        def __init__(self, hidden_size, eps=1e-12):
+            """Construct a layernorm module in the TF style (epsilon inside the square root).
+            """
+            super(BertLayerNorm, self).__init__()
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+            self.bias = nn.Parameter(torch.zeros(hidden_size))
+            self.variance_epsilon = eps
+
+        def forward(self, x):
+            u = x.mean(-1, keepdim=True)
+            s = (x - u).pow(2).mean(-1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+            return self.weight * x + self.bias
+
+
+class BertConfig(object):
+    def __init__(self,
+                 vocab_size,
+                 hidden_size=768,
+                 num_hidden_layers=12,
+                 num_attention_heads=12,
+                 intermediate_size=3072,
+                 hidden_act="gelu",
+                 hidden_dropout_prob=0.1,
+                 max_position_embeddings=512,
+                 attention_probs_dropout_prob=0.1,
+                 type_vocab_size=2):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.type_vocab_size = type_vocab_size
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config):
+        super(BertSelfAttention, self).__init__()
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, attention_mask):
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
+        if attention_mask is not None:
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        return context_layer
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super(BertSelfOutput, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config):
+        super(BertAttention, self).__init__()
+        self.self = BertSelfAttention(config)
+        self.output = BertSelfOutput(config)
+
+    def forward(self, input_tensor, attention_mask):
+        self_output = self.self(input_tensor, attention_mask)
+        attention_output = self.output(self_output, input_tensor)
+        return attention_output
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super(BertIntermediate, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.intermediate_act_fn = gelu
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super(BertOutput, self).__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config):
+        super(BertLayer, self).__init__()
+        self.attention = BertAttention(config)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(self, hidden_states, attention_mask):
+        attention_output = self.attention(hidden_states, attention_mask)
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super(BertEncoder, self).__init__()
+        layer = BertLayer(config)
+        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+
+    def forward(self, hidden_states, attention_mask=None, output_all_encoded_layers=True):
+        all_encoder_layers = []
+        for layer_module in self.layer:
+            hidden_states = layer_module(hidden_states, attention_mask)
+            if output_all_encoded_layers:
+                all_encoder_layers.append(hidden_states)
+        if not output_all_encoded_layers:
+            all_encoder_layers.append(hidden_states)
+        return all_encoder_layers
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings.
+    """
+
+    def __init__(self, config):
+        super(BertEmbeddings, self).__init__()
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids, token_type_ids=None):
+        seq_length = input_ids.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids[:, :, 0])
+
+        position_embeddings = self.position_embeddings(position_ids)
+
+        embeddings = input_ids + position_embeddings
+        # embeddings = input_ids
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, config):
+        super(PositionalEncoding, self).__init__()
+        emb_dim = config.hidden_size
+        max_len = config.max_position_embeddings
+        self.position_enc = self.position_encoding_init(max_len, emb_dim)
+
+    @staticmethod
+    def position_encoding_init(n_position, emb_dim):
+        ''' Init the sinusoid position encoding table '''
+
+        # keep dim 0 for padding token position encoding zero vector
+        position_enc = np.array([
+            [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)]
+            if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)])
+
+        position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # apply sin on 0th,2nd,4th...emb_dim
+        position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # apply cos on 1st,3rd,5th...emb_dim
+        return torch.from_numpy(position_enc).type(torch.FloatTensor)
+
+    def forward(self, word_seq):
+        position_encoding = self.position_enc.unsqueeze(0).expand_as(word_seq)
+        position_encoding = position_encoding.to(word_seq.device)
+        word_pos_encoded = word_seq + position_encoding
+        return word_pos_encoded
+
+
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super(BertPooler, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
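For orientation, a minimal sketch of how these modules compose; the hidden size, batch size, and sequence length below are arbitrary illustrative values, and the input is a pre-embedded feature sequence rather than token ids:

    import torch
    from attention_modules import BertConfig, BertEncoder, BertPooler

    config = BertConfig(vocab_size=50265, hidden_size=768,
                        num_hidden_layers=2, num_attention_heads=12)
    encoder = BertEncoder(config)
    pooler = BertPooler(config)

    features = torch.randn(4, 16, 768)            # (batch, seq_len, hidden_size)
    layer_outputs = encoder(features)             # one tensor per layer by default
    pooled = pooler(layer_outputs[-1])            # summary of the first position
    print(layer_outputs[-1].shape, pooled.shape)  # (4, 16, 768) and (4, 768)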
code/data.py ADDED
@@ -0,0 +1,297 @@
+import sys
+sys.path.append('..')
+
+from torch.utils.data import Dataset
+import pickle
+import random
+from . import LyricsCommentData
+
+class LyricsCommentsDataset(Dataset):
+
+    def __init__(self, random=False):
+        super(LyricsCommentsDataset, self).__init__()
+        self.random = random
+        with open("dataset.pkl", "rb") as f:
+            self.data = pickle.load(f)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        lyrics = self.data[item].lyrics
+        # if random:
+        #     comment = random.choice(self.data[item].comments)
+        # else:
+        comment = self.data[item].comments[0]
+        # the longest?
+        for i, (tmp_item, _) in enumerate(self.data[item].comments):
+            if len(tmp_item) > len(comment[0]):
+                comment = self.data[item].comments[i]
+
+        comment = comment[0]  # keep comments w/o rating
+
+        return [lyrics, comment]
+
+
+class LyricsCommentsDatasetClean(Dataset):
+
+    def __init__(self, random=False):
+        super(LyricsCommentsDatasetClean, self).__init__()
+        self.random = random
+        with open("cleaned_dataset.pkl", "rb") as f:
+            self.data = pickle.load(f)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        lyrics = self.data[item].lyrics
+        comment = self.data[item].comment
+
+        return [lyrics, comment]
+
+
+class LyricsCommentsDatasetPsuedo(Dataset):
+
+    def __init__(self, dataset_path, random=False):
+        super(LyricsCommentsDatasetPsuedo, self).__init__()
+        self.random = random
+        with open(dataset_path, "rb") as f:
+            self.data = pickle.load(f)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        lyrics = self.data[item].lyrics.replace('\n', ';')
+        comment = self.data[item].comment
+
+        return [lyrics, comment]
+
+
+class LyricsCommentsDatasetPsuedo_fusion(Dataset):
+
+    def __init__(self, dataset_path):
+        super(LyricsCommentsDatasetPsuedo_fusion, self).__init__()
+        with open(dataset_path, "rb") as f:
+            self.data = pickle.load(f)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, item):
+        lyrics = self.data[item].lyrics.replace('\n', ';')
+        comment = self.data[item].comment
+        music_id = self.data[item].music4all_id
+
+        return [lyrics, comment, music_id]
+
+
+from torch.utils.data import Dataset, DataLoader
+import torch
+from MusicData import MusicData
+import csv
+import os
+from pydub import AudioSegment
+import matplotlib.pyplot as plt
+from scipy.io import wavfile
+from tempfile import mktemp
+from scipy import signal
+import numpy as np
+import torchaudio
+import transformers
+import nltk
+
+
+class Music4AllDataset(Dataset):
+    def __init__(self,
+                 mel_bins,
+                 audio_length,
+                 pad_length,
+                 tag_file_path=r"Music4All/music4all/id_genres.csv",
+                 augment=True):
+        self.tag_file_path = tag_file_path
+        self.allow_cache = True
+        self.mel_bins = mel_bins
+        self.audio_length = audio_length
+        self.pad_length = pad_length
+        self.augment = augment
+        # read all tags
+        tags_file = open(tag_file_path, 'r', encoding='utf-8')
+        self.tags_reader = list(csv.reader(tags_file, delimiter='\t'))[1:]
+        tags_file.close()
+        if self.augment:
+            self.data_augmentation()
+
+    def data_augmentation(self):
+        pass
+
+    def __len__(self):
+        return len(self.tags_reader)
+
+    def __getitem__(self, item):
+        """
+
+        :param item: index
+        :return: tags and mel-spectrogram.
+        """
+        id = self.tags_reader[item][0]
+        tags = self.tags_reader[item][1]  # .split(',')
+
+        # pad tags
+        # if len(tags) >= self.pad_length:
+        #     tags = tags[:self.pad_length]
+        # else:
+        #     for i in range(self.pad_length - len(tags)):
+        #         tags.append("[PAD]")
+
+        spec_path = os.path.join("Music4All/temp_data/specs/data_cache/", id + ".npy")
+        exist_cache = os.path.isfile(spec_path)
+        # search cache
+        # if exist cache, load
+        if self.allow_cache and exist_cache:
+            spectrogram = torch.Tensor(np.load(spec_path))
+        # if does not exist, calculate and save
+        else:
+            audio_path = os.path.join("Music4All/music4all/audios",
+                                      id + '.mp3'
+                                      )
+            (data, sample_rate) = torchaudio.backend.sox_io_backend.load(audio_path)
+            spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=self.mel_bins,
+                                                               n_fft=512,
+                                                               sample_rate=sample_rate,
+                                                               f_max=8000.0,
+                                                               f_min=0.0,
+                                                               )(torch.Tensor(data))
+            # TODO: There is a huge bug!
+            # cut length
+            if self.audio_length is not None:
+                spectrogram = spectrogram[:, :, :self.audio_length]
+            # to mono
+            spectrogram = spectrogram[0, :, :].unsqueeze(0)
+
+            if self.allow_cache:
+                np.save(spec_path, spectrogram.numpy())
+
+        return tags, spectrogram
+
+
+class MusCapsDataset(Dataset):
+    def __init__(self,
+                 mel_bins,
+                 audio_length,
+                 pad_length,
+                 tag_file_path=r"Music4All/music4all/id_genres.csv",
+                 augment=True):
+        self.tag_file_path = tag_file_path
+        self.allow_cache = True
+        self.mel_bins = mel_bins
+        self.audio_length = audio_length
+        self.pad_length = pad_length
+        self.augment = augment
+        # read all tags
+        tags_file = open(tag_file_path, 'r', encoding='utf-8')
+        self.tags_reader = list(csv.reader(tags_file, delimiter='\t'))[1:]
+        tags_file.close()
+        if self.augment:
+            self.data_augmentation()
+
+    def data_augmentation(self):
+        pass
+
+    def __len__(self):
+        return len(self.tags_reader)
+
+    def __getitem__(self, item):
+        """
+
+        :param item: index
+        :return: tags and mel-spectrogram.
+        """
+        id = self.tags_reader[item][0]
+        tags = self.tags_reader[item][1]  # .split(',')
+
+        # pad tags
+        # if len(tags) >= self.pad_length:
+        #     tags = tags[:self.pad_length]
+        # else:
+        #     for i in range(self.pad_length - len(tags)):
+        #         tags.append("[PAD]")
+
+        spec_path = os.path.join("Music4All/temp_data/specs/data_cache/", id + ".npy")
+        exist_cache = os.path.isfile(spec_path)
+        # search cache
+        # if exist cache, load
+        if self.allow_cache and exist_cache:
+            spectrogram = torch.Tensor(np.load(spec_path))
+        # if does not exist, calculate and save
+        else:
+            audio_path = os.path.join("Music4All/music4all/audios",
+                                      id + '.mp3'
+                                      )
+            (data, sample_rate) = torchaudio.backend.sox_io_backend.load(audio_path)
+            spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=self.mel_bins,
+                                                               n_fft=512,
+                                                               sample_rate=sample_rate,
+                                                               f_max=8000.0,
+                                                               f_min=0.0,
+                                                               )(torch.Tensor(data))
+            # cut length
+            if self.audio_length is not None:
+                spectrogram = spectrogram[:, :, :self.audio_length]
+            # to mono
+            spectrogram = spectrogram[0, :, :].unsqueeze(0)
+            np.save(spec_path, spectrogram.numpy())
+
+        return tags, spectrogram
+
+class GTZANDataset(Dataset):
+    def __init__(self, raw_dataset, is_augment=True, window=1366):
+        self.raw = raw_dataset
+        self.data = list()
+        self.mel_bins = 96
+        self.gtzan_genres = [
+            "blues",
+            "classical",
+            "country",
+            "disco",
+            "hiphop",
+            "jazz",
+            "metal",
+            "pop",
+            "reggae",
+            "rock",
+        ]
+        self.is_augment = is_augment
+        self.window = window
+        self.init()
+
+    def init(self):
+        for i, (waveform, sample_rate, label) in enumerate(self.raw):
+            spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=self.mel_bins)(torch.Tensor(waveform))
+            if self.is_augment:
+                self.augment(spectrogram, label)
+            else:
+                self.data.append((spectrogram[:, :, :self.window], label))
+
+    def augment(self, spectrogram, label):
+        length = spectrogram.shape[-1]  # length
+        # augment audio with sliding window
+        hop_length = 250
+        slices = (length - self.window) // hop_length
+        for i in range(slices):
+            self.data.append((spectrogram[:, :, i * hop_length:self.window + i * hop_length], label))
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        spectrogram, label = self.data[index]
+        label = self.gtzan_genres.index(label)
+        return spectrogram, label
+
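A minimal sketch of how the pseudo-comment datasets above are typically consumed; the pickle path matches the one used in eval.py and is assumed to contain records with lyrics, comment, and music4all_id fields:

    from torch.utils.data import DataLoader
    from data import LyricsCommentsDatasetPsuedo_fusion

    dataset = LyricsCommentsDatasetPsuedo_fusion("dataset_test.pkl")
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    for lyrics, comment, music_id in loader:  # each is a batch of 8 strings
        print(len(lyrics), music_id[0])
        break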
code/eval.py ADDED
@@ -0,0 +1,87 @@
+import torch
+from data import LyricsCommentsDatasetPsuedo_fusion
+from torch import utils, nn
+from model import CommentGenerator
+from model_fusion import CommentGenerator_fusion
+import transformers
+import datasets
+from tqdm import tqdm
+import statistics
+import os
+DATASET_PATH = "dataset_test.pkl"
+MODEL_PATH = "model/bart_fusion_full.pt"
+# MODEL_NAME = "bart"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+
+test_dataset = LyricsCommentsDatasetPsuedo_fusion(DATASET_PATH)
+dataset_length = len(test_dataset)
+
+test_dataloader = utils.data.DataLoader(test_dataset,
+                                        # batch_size=len(valid_dataset),
+                                        batch_size=32,
+                                        shuffle=False)
+
+if 'baseline' in MODEL_PATH:
+    model = CommentGenerator().cuda()
+else:
+    model = CommentGenerator_fusion().cuda()
+model.load_state_dict(torch.load(MODEL_PATH))
+
+model.eval()
+
+samples_list = list()
+# generate
+for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)):
+    if 'baseline' in MODEL_PATH:
+        with torch.no_grad():
+            output_samples = model.generate(lyrics)
+    else:
+        with torch.no_grad():
+            output_samples = model.generate(lyrics, music_id)
+    samples_list.append(output_samples)
+
+# ------ ROUGE ------ #
+
+metrics = datasets.load_metric('rouge')  # , 'sacrebleu', 'meteor', 'bertscore'
+
+for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)):
+    output_samples = samples_list[batch_index]
+    metrics.add_batch(predictions=output_samples, references=comment)
+
+score = metrics.compute()
+print(score)
+
+# ------ BLEU ------ #
+
+metrics = datasets.load_metric('sacrebleu')  # , 'sacrebleu', 'meteor', 'bertscore'
+
+for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)):
+    output_samples = samples_list[batch_index]
+    metrics.add_batch(predictions=output_samples, references=[[i] for i in comment])
+
+score = metrics.compute()
+print(score)
+
+# ------ BERTScore ------ #
+
+metrics = datasets.load_metric('bertscore')  # , 'sacrebleu', 'meteor', 'bertscore'
+
+for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)):
+    output_samples = samples_list[batch_index]
+    metrics.add_batch(predictions=output_samples, references=[[i] for i in comment])
+
+score = metrics.compute(lang='en')
+score = statistics.mean(score['f1'])
+print(score)
+
+# ------ METEOR ------ #
+
+metrics = datasets.load_metric('meteor')  # , 'sacrebleu', 'meteor', 'bertscore'
+
+for batch_index, [lyrics, comment, music_id] in enumerate(tqdm(test_dataloader)):
+    output_samples = samples_list[batch_index]
+    metrics.add_batch(predictions=output_samples, references=[[i] for i in comment])
+
+score = metrics.compute()
+print(score)
code/model.py ADDED
@@ -0,0 +1,56 @@
+import torch
+from torch import nn
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+
+class CommentGenerator(nn.Module):
+    def __init__(self):
+        super(CommentGenerator, self).__init__()
+        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+        self.bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+        # self.bart_config = BartConfig()
+        self.condition = None
+
+    def forward(self, input_sentence_list, labels=None):
+        encoded_input = self.tokenizer(
+            input_sentence_list,
+            padding=True,
+            truncation=True,
+            max_length=512,
+            return_tensors='pt',
+        )
+        if labels is not None:
+            labels = self.tokenizer(
+                labels,
+                padding=True,
+                truncation=True,
+                max_length=512,
+                return_tensors='pt',
+            )
+        output = self.bart(input_ids=encoded_input['input_ids'].cuda(),
+                           attention_mask=encoded_input['attention_mask'].cuda(),
+                           # guard against labels=None so forward() also works without targets
+                           labels=labels['input_ids'].cuda() if labels is not None else None,
+                           )
+        return output
+
+    def generate(self, input_sentence_list, is_cuda=True):
+        encoded_input = self.tokenizer(input_sentence_list,
+                                       padding=True,
+                                       truncation=True,
+                                       return_tensors='pt',
+                                       )
+        output_ids = self.bart.generate(encoded_input['input_ids'].cuda(),
+                                        num_beams=4,
+                                        max_length=512,
+                                        early_stopping=True,
+                                        do_sample=True)
+        return ([self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                 for g in output_ids])
+
+# tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+# encoded_input = tokenizer(['Hello all', 'Hi all'], return_tensors='pt')
+# print(encoded_input)
+
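A minimal generation sketch for the baseline model above, assuming a CUDA device is available (the class moves input tensors to GPU internally) and that facebook/bart-base can be downloaded; the lyrics string is an arbitrary placeholder:

    from model import CommentGenerator

    generator = CommentGenerator().cuda()
    lyrics = "I walk a lonely road; the only one that I have ever known"
    comments = generator.generate([lyrics])  # one generated comment per input string
    print(comments[0])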
code/model_fusion.py ADDED
@@ -0,0 +1,69 @@
+import torch
+from torch import nn
+from transformers import BartTokenizer
+from modeling_bart import BartForMultimodalGeneration
+from music_encoder import CNNSA
+
+
+
+class CommentGenerator_fusion(nn.Module):
+    def __init__(self):
+        super(CommentGenerator_fusion, self).__init__()
+        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+        model_path = "best_model.pth"
+        self.music_encoder = CNNSA().cuda()
+        self.music_encoder.load_state_dict(torch.load(model_path))
+        # trial: fix music encoder's params
+        for params in self.music_encoder.parameters():
+            params.requires_grad = False
+
+        self.bart = BartForMultimodalGeneration.from_pretrained("facebook/bart-base",
+                                                                fusion_layers=[4, 5],  # [4,5]
+                                                                use_forget_gate=False,  # [True]
+                                                                dim_common=768,  # 256
+                                                                n_attn_heads=1).cuda()
+
+    def forward(self, input_sentence_list, music_ids, labels=None):
+        encoded_input = self.tokenizer(
+            input_sentence_list,
+            padding=True,
+            truncation=True,
+            max_length=512,
+            return_tensors='pt',
+        )
+        if labels is not None:
+            labels = self.tokenizer(
+                labels,
+                padding=True,
+                truncation=True,
+                max_length=512,
+                return_tensors='pt',
+            )
+        music_features = self.music_encoder(music_ids)
+        output = self.bart(input_ids=encoded_input['input_ids'].cuda(),
+                           attention_mask=encoded_input['attention_mask'].cuda(),
+                           # guard against labels=None so forward() also works without targets
+                           labels=labels['input_ids'].cuda() if labels is not None else None,
+                           music_features=music_features
+                           )
+        return output
+
+    def generate(self, input_sentence_list, music_ids, is_cuda=True):
+        encoded_input = self.tokenizer(input_sentence_list,
+                                       padding=True,
+                                       truncation=True,
+                                       return_tensors='pt',
+                                       )
+        music_features = self.music_encoder(music_ids)
+        output_ids = self.bart.generate(encoded_input['input_ids'].cuda(),
+                                        num_beams=5,
+                                        max_length=512,
+                                        early_stopping=True,
+                                        do_sample=True,
+                                        music_features=music_features)
+        return ([self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+                 for g in output_ids])
+
+# tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+# encoded_input = tokenizer(['Hello all', 'Hi all'], return_tensors='pt')
+# print(encoded_input)
code/modeling_bart.py ADDED
@@ -0,0 +1,1483 @@
+# coding=utf-8
+
+# Revised by anonymous.
+
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch BART model. """
+import copy
+import math
+import random
+import warnings
+from typing import Optional, Tuple
+import numpy as np
+
+import torch.nn.functional as F
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
+    Seq2SeqSequenceClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.models.bart.configuration_bart import BartConfig
+
+from music_encoder import CNNSA
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/bart-large"
+_CONFIG_FOR_DOC = "BartConfig"
+_TOKENIZER_FOR_DOC = "BartTokenizer"
+
+
+BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "facebook/bart-large",
+    # See all BART models at https://huggingface.co/models?filter=bart
+]
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), float("-inf"))
+    mask_cond = torch.arange(mask.size(-1))
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+
+
+class BartLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+        """`input_ids_shape` is expected to be [bsz x seqlen]."""
+        bsz, seq_len = input_ids_shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+        )
+        return super().forward(positions + self.offset)
+
+
+class BartAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim ** -0.5
+        self.is_decoder = is_decoder
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned aross GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class BartEncoderLayer(nn.Module):
+    def __init__(self, config: BartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.self_attn = BartAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                *(encoder_attention_heads,)*.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class BartDecoderLayer(nn.Module):
+    def __init__(self, config: BartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = BartAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = BartAttention(
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                *(encoder_attention_heads,)*.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size *(decoder_attention_heads,)*.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class BartClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        inner_dim: int,
+        num_classes: int,
+        pooler_dropout: float,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.dropout = nn.Dropout(p=pooler_dropout)
+        self.out_proj = nn.Linear(inner_dim, num_classes)
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
+class BartPretrainedModel(PreTrainedModel):
+    config_class = BartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (BartDecoder, BartEncoder)):
+            module.gradient_checkpointing = value
+
+    @property
+    def dummy_inputs(self):
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+        dummy_inputs = {
+            "attention_mask": input_ids.ne(pad_token),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+
+class PretrainedBartModel(BartPretrainedModel):
+    def __init_subclass__(self):
+        warnings.warn(
+            "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
+            FutureWarning,
+        )
+
+
+BART_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic
+    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
+    pruning heads etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
+    general usage and behavior.
+
+    Parameters:
+        config ([`BartConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+BART_GENERATION_EXAMPLE = r"""
+    Summarization example::
+
+        >>> from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
+
+        >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
+        >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
+
+        >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+        >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
+
+        >>> # Generate Summary
+        >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
+        >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
+
+    Mask filling example::
+
+        >>> from transformers import BartTokenizer, BartForConditionalGeneration
+        >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+        >>> TXT = "My friends are <mask> but they eat too many carbs."
+
+        >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+        >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
+        >>> logits = model(input_ids).logits
+
+        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+        >>> probs = logits[0, masked_index].softmax(dim=0)
+        >>> values, predictions = probs.topk(5)
+
+        >>> tokenizer.decode(predictions).split()
+"""
+
+BART_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`BartTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
+            details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`BartTokenizer`]. See
+            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
+            details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For translation and summarization training, `decoder_input_ids` should be provided. If no
+            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to
+            the right for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
+            also be used by default.
+
+            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and
+            modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+            `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`,
+            *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+            cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors
+            of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+            shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
643
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
644
+
645
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
646
+ (those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
647
+ instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated
648
+ vectors than the model's internal embedding lookup matrix.
649
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
650
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
651
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds`
652
+ have to be input (see `past_key_values`). This is useful if you want more control over how to convert
653
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
654
+
655
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds`
656
+ takes the value of `inputs_embeds`.
657
+ use_cache (`bool`, *optional*):
658
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up
659
+ decoding (see `past_key_values`).
660
+ output_attentions (`bool`, *optional*):
661
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
662
+ tensors for more detail.
663
+ output_hidden_states (`bool`, *optional*):
664
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
665
+ more detail.
666
+ return_dict (`bool`, *optional*):
667
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
668
+ """
669
+
670
+
671
+ class BartEncoder(BartPretrainedModel):
672
+ """
673
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
674
+ [`BartEncoderLayer`].
675
+
676
+ Args:
677
+ config: BartConfig
678
+ embed_tokens (nn.Embedding): output embedding
679
+ """
680
+
681
+ def __init__(self, config: BartConfig,
682
+ embed_tokens: Optional[nn.Embedding] = None,
683
+ fusion_layers=[5], # 5 is the last layer
684
+ use_forget_gate=True,
685
+ dim_common=256,
686
+ n_attn_heads=1):
687
+ super().__init__(config)
688
+
689
+ self.dropout = config.dropout
690
+ self.layerdrop = config.encoder_layerdrop
691
+
692
+ embed_dim = config.d_model
693
+ self.padding_idx = config.pad_token_id
694
+ self.max_source_positions = config.max_position_embeddings
695
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
696
+
697
+ if embed_tokens is not None:
698
+ self.embed_tokens = embed_tokens
699
+ else:
700
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
701
+
702
+ self.embed_positions = BartLearnedPositionalEmbedding(
703
+ config.max_position_embeddings,
704
+ embed_dim,
705
+ )
706
+ self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
707
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
708
+
709
+ self.gradient_checkpointing = False
710
+
711
+ # ==================== Modification Starts ====================
712
+ # 1. params and variables
713
+ self.use_forget_gate = use_forget_gate
714
+ self.fusion_layers = fusion_layers
715
+ music_feature_dim = 256
716
+ text_feature_dim = embed_dim # 768
717
+
718
+ # 2. define attention
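+ # Shape conventions for the fusion attention defined below (batch-first): music features are
+ # assumed to arrive as (batch, S_v, 256) and text hidden states as (batch, S_t, 768);
+ # _linear_1/_linear_2 project music to K/V in dim_common, _linear_3 projects text to Q in
+ # dim_common, and the optional forget gate `fg` maps [attended music ; text] to a dim_common gate.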
719
+ self._linear_1 = nn.Linear(music_feature_dim, dim_common) # K
720
+ self._linear_2 = nn.Linear(music_feature_dim, dim_common) # V
721
+ self._linear_3 = nn.Linear(text_feature_dim, dim_common) # Q
722
+ self._multi_head_attn = nn.MultiheadAttention(dim_common, n_attn_heads)
723
+ self._linear_4 = nn.Linear(text_feature_dim + dim_common, text_feature_dim)  # TODO: currently unused -- the concat-and-project path is commented out in forward() in favour of a scaled residual
724
+ if use_forget_gate:
725
+ self.fg = nn.Linear(dim_common + text_feature_dim, dim_common)
726
+
727
+ # ==================== Modification Ends ====================
728
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
729
+ self.sigmoid = nn.Sigmoid()
730
+
731
+ # Initialize weights and apply final processing
732
+ self.post_init()
733
+
734
+ def get_input_embeddings(self):
735
+ return self.embed_tokens
736
+
737
+ def set_input_embeddings(self, value):
738
+ self.embed_tokens = value
739
+
740
+ def forward(
741
+ self,
742
+ input_ids=None,
743
+ attention_mask=None,
744
+ head_mask=None,
745
+ inputs_embeds=None,
746
+ output_attentions=None,
747
+ output_hidden_states=None,
748
+ return_dict=None,
749
+ music_features=None,
750
+ music_len=None
751
+ ):
752
+ r"""
753
+ Args:
754
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
755
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
756
+ provide it.
757
+
758
+ Indices can be obtained using [`BartTokenizer`]. See
759
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
760
+ for details.
761
+
762
+ [What are input IDs?](../glossary#input-ids)
763
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
764
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
765
+
766
+ - 1 for tokens that are **not masked**,
767
+ - 0 for tokens that are **masked**.
768
+
769
+ [What are attention masks?](../glossary#attention-mask)
770
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
771
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
772
+
773
+ - 1 indicates the head is **not masked**,
774
+ - 0 indicates the head is **masked**.
775
+
776
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
777
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
778
+ representation. This is useful if you want more control over how to convert `input_ids` indices
779
+ into associated vectors than the model's internal embedding lookup matrix.
780
+ output_attentions (`bool`, *optional*):
781
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
782
+ returned tensors for more detail.
783
+ output_hidden_states (`bool`, *optional*):
784
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
785
+ for more detail.
786
+ return_dict (`bool`, *optional*):
787
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
788
+ """
789
+
790
+ # ==================== Modification Starts ====================
791
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
792
+ output_hidden_states = (
793
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
794
+ )
795
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
796
+
797
+ # retrieve input_ids and inputs_embeds
798
+ if input_ids is not None and inputs_embeds is not None:
799
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
800
+ elif input_ids is not None:
801
+ input_shape = input_ids.size()
802
+ input_ids = input_ids.view(-1, input_shape[-1])
803
+ elif inputs_embeds is not None:
804
+ input_shape = inputs_embeds.size()[:-1]
805
+ else:
806
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
807
+
808
+ if inputs_embeds is None:
809
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
810
+
811
+ embed_pos = self.embed_positions(input_shape)
812
+
813
+ hidden_states = inputs_embeds + embed_pos
814
+ hidden_states = self.layernorm_embedding(hidden_states)
815
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
816
+
817
+ # expand attention_mask
818
+ if attention_mask is not None:
819
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
820
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
821
+
822
+ encoder_states = () if output_hidden_states else None
823
+ all_attentions = () if output_attentions else None
824
+
825
+ # check if head_mask has a correct number of layers specified if desired
826
+ if head_mask is not None:
827
+ if head_mask.size()[0] != (len(self.layers)):
828
+ raise ValueError(
829
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
830
+ )
831
+
832
+ for idx, encoder_layer in enumerate(self.layers):
833
+ if output_hidden_states:
834
+ encoder_states = encoder_states + (hidden_states,)
835
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
836
+ dropout_probability = random.uniform(0, 1)
837
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
838
+ layer_outputs = (None, None)
839
+ else:
840
+ if self.gradient_checkpointing and self.training:
841
+
842
+ def create_custom_forward(module):
843
+ def custom_forward(*inputs):
844
+ return module(*inputs, output_attentions)
845
+
846
+ return custom_forward
847
+
848
+ layer_outputs = torch.utils.checkpoint.checkpoint(
849
+ create_custom_forward(encoder_layer),
850
+ hidden_states,
851
+ attention_mask,
852
+ (head_mask[idx] if head_mask is not None else None),
853
+ )
854
+ else:
855
+ layer_outputs = encoder_layer(
856
+ hidden_states,
857
+ attention_mask,
858
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
859
+ output_attentions=output_attentions,
860
+ )
861
+
862
+ hidden_states = layer_outputs[0]
863
+
864
+ # ==================== music-text fusion =====================
865
+
866
+ # NOTE: unused helper -- the same gating logic is inlined under `use_forget_gate` below.
+ def forget_gate(music_features, text_features):
867
+ forget_mask = self.fg(torch.cat((music_features, text_features), 2))
868
+ forget_mask = self.sigmoid(forget_mask)
869
+ forget_mask = F.dropout(forget_mask, p=self.dropout, training=self.training)
870
+ music_features = forget_mask.mul(music_features)
871
+ return music_features
872
+
873
+ if idx in self.fusion_layers:
874
+ '''
875
+ => K_a = linear_1(V) in (S_v, D_a)
876
+ => V_a = linear_2(V) in (S_v, D_a)
877
+ => Q_a = linear_3(T) in (S_t, D_a)
878
+ => T_out = MultiHeadAttn(Q_a, K_a, V_a) in (S_t, D_a)
879
+ => T_out = linear_4(concat(T, T_out)) in (S_t, D_t)
880
+ => T_out = T + T_out (Residual Connection)
881
+ '''
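+ # Note: nn.MultiheadAttention is used without batch_first, so it expects (seq_len, batch, dim);
+ # the transposes below move K/V (from the music features) and Q (from the text hidden states)
+ # into that layout and then bring the attended output back to batch-first.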
882
+ K = self._linear_1(music_features).transpose(0, 1)
883
+ V = self._linear_2(music_features).transpose(0, 1)
884
+ Q = self._linear_3(hidden_states).transpose(0, 1)
885
+ attn_output, _ = self._multi_head_attn(Q, K, V)
886
+ attn_output = attn_output.transpose(0, 1)
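+ # Optional forget gate: a sigmoid gate computed from [attended music ; text] rescales the
+ # attended music features element-wise before they are mixed into the text stream through the
+ # 0.1-scaled residual connection and the final layer norm below.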
887
+ if self.use_forget_gate:
888
+ forget_mask = self.fg(torch.cat((attn_output, hidden_states), 2))
889
+ forget_mask = self.sigmoid(forget_mask)
890
+ forget_mask = F.dropout(forget_mask, p=self.dropout, training=self.training)
891
+ attn_output = forget_mask.mul(attn_output)
892
+ # output = self._linear_4(torch.cat((hidden_states, attn_output), 2))
893
+
894
+ # Residual Connection
895
+ hidden_states = hidden_states + 0.1 * attn_output
896
+ hidden_states = self.final_layer_norm(hidden_states)
897
+
898
+ # ==================== music-text fusion =====================
899
+
900
+ if output_attentions:
901
+ all_attentions = all_attentions + (layer_outputs[1],)
902
+
903
+ if output_hidden_states:
904
+ encoder_states = encoder_states + (hidden_states,)
905
+
906
+ if not return_dict:
907
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
908
+ return BaseModelOutput(
909
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
910
+ )
911
+
912
+
913
+ class BartDecoder(BartPretrainedModel):
914
+ """
915
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
916
+
917
+ Args:
918
+ config: BartConfig
919
+ embed_tokens (nn.Embedding): output embedding
920
+ """
921
+
922
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
923
+ super().__init__(config)
924
+ self.dropout = config.dropout
925
+ self.layerdrop = config.decoder_layerdrop
926
+ self.padding_idx = config.pad_token_id
927
+ self.max_target_positions = config.max_position_embeddings
928
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
929
+
930
+ if embed_tokens is not None:
931
+ self.embed_tokens = embed_tokens
932
+ else:
933
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
934
+
935
+ self.embed_positions = BartLearnedPositionalEmbedding(
936
+ config.max_position_embeddings,
937
+ config.d_model,
938
+ )
939
+ self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
940
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
941
+
942
+ self.gradient_checkpointing = False
943
+ # Initialize weights and apply final processing
944
+ self.post_init()
945
+
946
+ def get_input_embeddings(self):
947
+ return self.embed_tokens
948
+
949
+ def set_input_embeddings(self, value):
950
+ self.embed_tokens = value
951
+
952
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
953
+ # create causal mask
954
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
955
+ combined_attention_mask = None
956
+ if input_shape[-1] > 1:
957
+ combined_attention_mask = _make_causal_mask(
958
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
959
+ ).to(self.device)
960
+
961
+ if attention_mask is not None:
962
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
963
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
964
+ combined_attention_mask = (
965
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
966
+ )
967
+
968
+ return combined_attention_mask
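+ # For intuition about the mask built above (illustrative values only): for a target length of 3
+ # with no padding, the causal part is an additive matrix of the form
+ #     [[0, -inf, -inf],
+ #      [0,    0, -inf],
+ #      [0,    0,    0]]
+ # so position i can only attend to positions <= i; padding positions from `attention_mask` are
+ # added on top as further large negative entries.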
969
+
970
+ def forward(
971
+ self,
972
+ input_ids=None,
973
+ attention_mask=None,
974
+ encoder_hidden_states=None,
975
+ encoder_attention_mask=None,
976
+ head_mask=None,
977
+ cross_attn_head_mask=None,
978
+ past_key_values=None,
979
+ inputs_embeds=None,
980
+ use_cache=None,
981
+ output_attentions=None,
982
+ output_hidden_states=None,
983
+ return_dict=None,
984
+ ):
985
+ r"""
986
+ Args:
987
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
988
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
989
+ provide it.
990
+
991
+ Indices can be obtained using [`BartTokenizer`]. See
992
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
993
+ for details.
994
+
995
+ [What are input IDs?](../glossary#input-ids)
996
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
997
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
998
+
999
+ - 1 for tokens that are **not masked**,
1000
+ - 0 for tokens that are **masked**.
1001
+
1002
+ [What are attention masks?](../glossary#attention-mask)
1003
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1004
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1005
+ of the decoder.
1006
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1007
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
1008
+ selected in `[0, 1]`:
1009
+
1010
+ - 1 for tokens that are **not masked**,
1011
+ - 0 for tokens that are **masked**.
1012
+
1013
+ [What are attention masks?](../glossary#attention-mask)
1014
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1015
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1016
+
1017
+ - 1 indicates the head is **not masked**,
1018
+ - 0 indicates the head is **masked**.
1019
+
1020
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1021
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1022
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1023
+
1024
+ - 1 indicates the head is **not masked**,
1025
+ - 0 indicates the head is **masked**.
1026
+
1027
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1028
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2
1029
+ tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
1030
+ tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1031
+
1032
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1033
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential
1034
+ decoding.
1035
+
1036
+ If `past_key_values` are used, the user can optionally input only the last
1037
+ `decoder_input_ids` (those that don't have their past key value states given to this model) of
1038
+ shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1039
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
+ representation. This is useful if you want more control over how to convert `input_ids` indices
1040
+ into associated vectors than the model's internal embedding lookup matrix.
1041
+ output_attentions (`bool`, *optional*):
1042
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1043
+ returned tensors for more detail.
1044
+ output_hidden_states (`bool`, *optional*):
1045
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1046
+ for more detail.
1047
+ return_dict (`bool`, *optional*):
1048
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1049
+ """
1050
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1051
+ output_hidden_states = (
1052
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1053
+ )
1054
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1055
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1056
+
1057
+ # retrieve input_ids and inputs_embeds
1058
+ if input_ids is not None and inputs_embeds is not None:
1059
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1060
+ elif input_ids is not None:
1061
+ input_shape = input_ids.size()
1062
+ input_ids = input_ids.view(-1, input_shape[-1])
1063
+ elif inputs_embeds is not None:
1064
+ input_shape = inputs_embeds.size()[:-1]
1065
+ else:
1066
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1067
+
1068
+ # past_key_values_length
1069
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1070
+
1071
+ if inputs_embeds is None:
1072
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1073
+
1074
+ attention_mask = self._prepare_decoder_attention_mask(
1075
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1076
+ )
1077
+
1078
+ # expand encoder attention mask
1079
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1080
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1081
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1082
+
1083
+ # embed positions
1084
+ positions = self.embed_positions(input_shape, past_key_values_length)
1085
+
1086
+ hidden_states = inputs_embeds + positions
1087
+ hidden_states = self.layernorm_embedding(hidden_states)
1088
+
1089
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1090
+
1091
+ # decoder layers
1092
+ all_hidden_states = () if output_hidden_states else None
1093
+ all_self_attns = () if output_attentions else None
1094
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1095
+ next_decoder_cache = () if use_cache else None
1096
+
1097
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1098
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1099
+ if attn_mask is not None:
1100
+ if attn_mask.size()[0] != (len(self.layers)):
1101
+ raise ValueError(
1102
+ "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
1103
+ )
1104
+
1105
+ for idx, decoder_layer in enumerate(self.layers):
1106
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1107
+ if output_hidden_states:
1108
+ all_hidden_states += (hidden_states,)
1109
+ dropout_probability = random.uniform(0, 1)
1110
+ if self.training and (dropout_probability < self.layerdrop):
1111
+ continue
1112
+
1113
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1114
+
1115
+ if self.gradient_checkpointing and self.training:
1116
+
1117
+ if use_cache:
1118
+ logger.warning(
1119
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1120
+ )
1121
+ use_cache = False
1122
+
1123
+ def create_custom_forward(module):
1124
+ def custom_forward(*inputs):
1125
+ # None for past_key_value
1126
+ return module(*inputs, output_attentions, use_cache)
1127
+
1128
+ return custom_forward
1129
+
1130
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1131
+ create_custom_forward(decoder_layer),
1132
+ hidden_states,
1133
+ attention_mask,
1134
+ encoder_hidden_states,
1135
+ encoder_attention_mask,
1136
+ head_mask[idx] if head_mask is not None else None,
1137
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
1138
+ None,
1139
+ )
1140
+ else:
1141
+
1142
+ layer_outputs = decoder_layer(
1143
+ hidden_states,
1144
+ attention_mask=attention_mask,
1145
+ encoder_hidden_states=encoder_hidden_states,
1146
+ encoder_attention_mask=encoder_attention_mask,
1147
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1148
+ cross_attn_layer_head_mask=(
1149
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1150
+ ),
1151
+ past_key_value=past_key_value,
1152
+ output_attentions=output_attentions,
1153
+ use_cache=use_cache,
1154
+ )
1155
+ hidden_states = layer_outputs[0]
1156
+
1157
+ if use_cache:
1158
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1159
+
1160
+ if output_attentions:
1161
+ all_self_attns += (layer_outputs[1],)
1162
+
1163
+ if encoder_hidden_states is not None:
1164
+ all_cross_attentions += (layer_outputs[2],)
1165
+
1166
+ # add hidden states from the last decoder layer
1167
+ if output_hidden_states:
1168
+ all_hidden_states += (hidden_states,)
1169
+
1170
+ next_cache = next_decoder_cache if use_cache else None
1171
+ if not return_dict:
1172
+ return tuple(
1173
+ v
1174
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
1175
+ if v is not None
1176
+ )
1177
+ return BaseModelOutputWithPastAndCrossAttentions(
1178
+ last_hidden_state=hidden_states,
1179
+ past_key_values=next_cache,
1180
+ hidden_states=all_hidden_states,
1181
+ attentions=all_self_attns,
1182
+ cross_attentions=all_cross_attentions,
1183
+ )
1184
+
1185
+
1186
+ @add_start_docstrings(
1187
+ "The bare BART Model outputting raw hidden-states without any specific head on top.",
1188
+ BART_START_DOCSTRING,
1189
+ )
1190
+ class BartModel(BartPretrainedModel):
1191
+ def __init__(self, config: BartConfig,
1192
+ fusion_layers=None,
1193
+ use_forget_gate=None,
1194
+ dim_common=256,
1195
+ n_attn_heads=1):
1196
+ super().__init__(config)
1197
+
1198
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1199
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
1200
+
1201
+ self.encoder = BartEncoder(config, self.shared, fusion_layers, use_forget_gate, dim_common, n_attn_heads)
1202
+ self.decoder = BartDecoder(config, self.shared)
1203
+
1204
+ # Initialize weights and apply final processing
1205
+ self.post_init()
1206
+
1207
+ def get_input_embeddings(self):
1208
+ return self.shared
1209
+
1210
+ def set_input_embeddings(self, value):
1211
+ self.shared = value
1212
+ self.encoder.embed_tokens = self.shared
1213
+ self.decoder.embed_tokens = self.shared
1214
+
1215
+ def get_encoder(self):
1216
+ return self.encoder
1217
+
1218
+ def get_decoder(self):
1219
+ return self.decoder
1220
+
1221
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1222
+ @add_code_sample_docstrings(
1223
+ processor_class=_TOKENIZER_FOR_DOC,
1224
+ checkpoint=_CHECKPOINT_FOR_DOC,
1225
+ output_type=Seq2SeqModelOutput,
1226
+ config_class=_CONFIG_FOR_DOC,
1227
+ )
1228
+ def forward(
1229
+ self,
1230
+ input_ids=None,
1231
+ attention_mask=None,
1232
+ decoder_input_ids=None,
1233
+ decoder_attention_mask=None,
1234
+ head_mask=None,
1235
+ decoder_head_mask=None,
1236
+ cross_attn_head_mask=None,
1237
+ encoder_outputs=None,
1238
+ past_key_values=None,
1239
+ inputs_embeds=None,
1240
+ decoder_inputs_embeds=None,
1241
+ use_cache=None,
1242
+ output_attentions=None,
1243
+ output_hidden_states=None,
1244
+ return_dict=None,
1245
+ music_features=None,
1246
+ music_len=None,
1247
+ ):
1248
+
1249
+ # different to other models, Bart automatically creates decoder_input_ids from
1250
+ # input_ids if no decoder_input_ids are provided
1251
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1252
+ if input_ids is None:
1253
+ raise ValueError(
1254
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
1255
+ "passed, `input_ids` cannot be `None`. Please pass either "
1256
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
1257
+ )
1258
+
1259
+ decoder_input_ids = shift_tokens_right(
1260
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
1261
+ )
1262
+
1263
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1264
+ output_hidden_states = (
1265
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1266
+ )
1267
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1268
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1269
+
1270
+ if encoder_outputs is None:
1271
+ encoder_outputs = self.encoder(
1272
+ input_ids=input_ids,
1273
+ attention_mask=attention_mask,
1274
+ head_mask=head_mask,
1275
+ inputs_embeds=inputs_embeds,
1276
+ output_attentions=output_attentions,
1277
+ output_hidden_states=output_hidden_states,
1278
+ return_dict=return_dict,
1279
+ music_features=music_features,
1280
+ music_len=music_len,
1281
+ )
1282
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1283
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1284
+ encoder_outputs = BaseModelOutput(
1285
+ last_hidden_state=encoder_outputs[0],
1286
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1287
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1288
+ )
1289
+
1290
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1291
+ decoder_outputs = self.decoder(
1292
+ input_ids=decoder_input_ids,
1293
+ attention_mask=decoder_attention_mask,
1294
+ encoder_hidden_states=encoder_outputs[0],
1295
+ encoder_attention_mask=attention_mask,
1296
+ head_mask=decoder_head_mask,
1297
+ cross_attn_head_mask=cross_attn_head_mask,
1298
+ past_key_values=past_key_values,
1299
+ inputs_embeds=decoder_inputs_embeds,
1300
+ use_cache=use_cache,
1301
+ output_attentions=output_attentions,
1302
+ output_hidden_states=output_hidden_states,
1303
+ return_dict=return_dict,
1304
+ )
1305
+
1306
+ if not return_dict:
1307
+ return decoder_outputs + encoder_outputs
1308
+
1309
+ return Seq2SeqModelOutput(
1310
+ last_hidden_state=decoder_outputs.last_hidden_state,
1311
+ past_key_values=decoder_outputs.past_key_values,
1312
+ decoder_hidden_states=decoder_outputs.hidden_states,
1313
+ decoder_attentions=decoder_outputs.attentions,
1314
+ cross_attentions=decoder_outputs.cross_attentions,
1315
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1316
+ encoder_hidden_states=encoder_outputs.hidden_states,
1317
+ encoder_attentions=encoder_outputs.attentions,
1318
+ )
1319
+
1320
+
1321
+ @add_start_docstrings(
1322
+ "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
1323
+ )
1324
+ class BartForMultimodalGeneration(BartPretrainedModel):
1325
+ base_model_prefix = "model"
1326
+ _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1327
+
1328
+ def __init__(self, config: BartConfig, fusion_layers=None, use_forget_gate=None, dim_common=256, n_attn_heads=1):
1329
+ super().__init__(config)
1330
+ self.model = BartModel(config, fusion_layers, use_forget_gate, dim_common, n_attn_heads)
1331
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1332
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1333
+
1334
+ # Initialize weights and apply final processing
1335
+ self.post_init()
1336
+
1337
+ def get_encoder(self):
1338
+ return self.model.get_encoder()
1339
+
1340
+ def get_decoder(self):
1341
+ return self.model.get_decoder()
1342
+
1343
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
1344
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
1345
+ self._resize_final_logits_bias(new_num_tokens)
1346
+ return new_embeddings
1347
+
1348
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
1349
+ old_num_tokens = self.final_logits_bias.shape[-1]
1350
+ if new_num_tokens <= old_num_tokens:
1351
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
1352
+ else:
1353
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1354
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1355
+ self.register_buffer("final_logits_bias", new_bias)
1356
+
1357
+ def get_output_embeddings(self):
1358
+ return self.lm_head
1359
+
1360
+ def set_output_embeddings(self, new_embeddings):
1361
+ self.lm_head = new_embeddings
1362
+
1363
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1364
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1365
+ @add_end_docstrings(BART_GENERATION_EXAMPLE)
1366
+ def forward(
1367
+ self,
1368
+ input_ids=None,
1369
+ attention_mask=None,
1370
+ decoder_input_ids=None,
1371
+ decoder_attention_mask=None,
1372
+ head_mask=None,
1373
+ decoder_head_mask=None,
1374
+ cross_attn_head_mask=None,
1375
+ encoder_outputs=None,
1376
+ past_key_values=None,
1377
+ inputs_embeds=None,
1378
+ decoder_inputs_embeds=None,
1379
+ labels=None,
1380
+ use_cache=None,
1381
+ output_attentions=None,
1382
+ output_hidden_states=None,
1383
+ return_dict=None,
1384
+ music_features=None,
1385
+ music_len=None,
1386
+ ):
1387
+ r"""
1388
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1389
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1390
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1391
+
1392
+ Returns:
1393
+ """
1394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1395
+
1396
+ if labels is not None:
1397
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1398
+ decoder_input_ids = shift_tokens_right(
1399
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1400
+ )
1401
+
1402
+ outputs = self.model(
1403
+ input_ids,
1404
+ attention_mask=attention_mask,
1405
+ decoder_input_ids=decoder_input_ids,
1406
+ encoder_outputs=encoder_outputs,
1407
+ decoder_attention_mask=decoder_attention_mask,
1408
+ head_mask=head_mask,
1409
+ decoder_head_mask=decoder_head_mask,
1410
+ cross_attn_head_mask=cross_attn_head_mask,
1411
+ past_key_values=past_key_values,
1412
+ inputs_embeds=inputs_embeds,
1413
+ decoder_inputs_embeds=decoder_inputs_embeds,
1414
+ use_cache=use_cache,
1415
+ output_attentions=output_attentions,
1416
+ output_hidden_states=output_hidden_states,
1417
+ return_dict=return_dict,
1418
+ music_features=music_features,
1419
+ music_len=music_len,
1420
+ )
1421
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1422
+
1423
+ masked_lm_loss = None
1424
+ if labels is not None:
1425
+ loss_fct = CrossEntropyLoss()
1426
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1427
+
1428
+ if not return_dict:
1429
+ output = (lm_logits,) + outputs[1:]
1430
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1431
+
1432
+ return Seq2SeqLMOutput(
1433
+ loss=masked_lm_loss,
1434
+ logits=lm_logits,
1435
+ past_key_values=outputs.past_key_values,
1436
+ decoder_hidden_states=outputs.decoder_hidden_states,
1437
+ decoder_attentions=outputs.decoder_attentions,
1438
+ cross_attentions=outputs.cross_attentions,
1439
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1440
+ encoder_hidden_states=outputs.encoder_hidden_states,
1441
+ encoder_attentions=outputs.encoder_attentions,
1442
+ )
1443
+
1444
+ def prepare_inputs_for_generation(
1445
+ self,
1446
+ decoder_input_ids,
1447
+ past=None,
1448
+ attention_mask=None,
1449
+ head_mask=None,
1450
+ decoder_head_mask=None,
1451
+ cross_attn_head_mask=None,
1452
+ use_cache=None,
1453
+ encoder_outputs=None,
1454
+ **kwargs
1455
+ ):
1456
+ # cut decoder_input_ids if past is used
1457
+ if past is not None:
1458
+ decoder_input_ids = decoder_input_ids[:, -1:]
1459
+
1460
+ return {
1461
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1462
+ "encoder_outputs": encoder_outputs,
1463
+ "past_key_values": past,
1464
+ "decoder_input_ids": decoder_input_ids,
1465
+ "attention_mask": attention_mask,
1466
+ "head_mask": head_mask,
1467
+ "decoder_head_mask": decoder_head_mask,
1468
+ "cross_attn_head_mask": cross_attn_head_mask,
1469
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1470
+ }
1471
+
1472
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1473
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1474
+
1475
+ @staticmethod
1476
+ def _reorder_cache(past, beam_idx):
1477
+ reordered_past = ()
1478
+ for layer_past in past:
1479
+ # cached cross_attention states don't have to be reordered -> they are always the same
1480
+ reordered_past += (
1481
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1482
+ )
1483
+ return reordered_past
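+
+ # Usage sketch (illustrative only -- the checkpoint name, feature shapes and inputs below are
+ # assumptions, not part of the training scripts):
+ #
+ #     from transformers import BartTokenizer, BartConfig
+ #     config = BartConfig.from_pretrained("facebook/bart-base")
+ #     tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+ #     model = BartForMultimodalGeneration(config, fusion_layers=[5], use_forget_gate=True)
+ #     batch = tokenizer(["some lyrics", "more lyrics"], return_tensors="pt", padding=True)
+ #     music = torch.randn(2, 100, 256)  # (batch, music_seq_len, 256); 256 matches the
+ #                                       # music_feature_dim hard-coded in BartEncoder
+ #     out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"],
+ #                 labels=batch["input_ids"], music_features=music)
+ #     out.loss, out.logits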
code/music_encoder.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio
5
+ import os
6
+ import random
7
+
8
+ from attention_modules import BertConfig, BertEncoder, BertPooler
9
+
10
+
11
+ class Conv_1d(nn.Module):
12
+ def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
13
+ super(Conv_1d, self).__init__()
14
+ self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
15
+ self.bn = nn.BatchNorm1d(output_channels)
16
+ self.relu = nn.ReLU()
17
+ self.mp = nn.MaxPool1d(pooling)
18
+ def forward(self, x):
19
+ out = self.mp(self.relu(self.bn(self.conv(x))))
20
+ return out
21
+
22
+
23
+ class Conv_2d(nn.Module):
24
+ def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
25
+ super(Conv_2d, self).__init__()
26
+ self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
27
+ self.bn = nn.BatchNorm2d(output_channels)
28
+ self.relu = nn.ReLU()
29
+ self.mp = nn.MaxPool2d(pooling)
30
+ def forward(self, x):
31
+ out = self.mp(self.relu(self.bn(self.conv(x))))
32
+ return out
33
+
34
+
35
+ class Res_2d(nn.Module):
36
+ def __init__(self, input_channels, output_channels, shape=3, stride=2):
37
+ super(Res_2d, self).__init__()
38
+ # convolution
39
+ self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
40
+ self.bn_1 = nn.BatchNorm2d(output_channels)
41
+ self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
42
+ self.bn_2 = nn.BatchNorm2d(output_channels)
43
+
44
+ # residual
45
+ self.diff = False
46
+ if (stride != 1) or (input_channels != output_channels):
47
+ self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
48
+ self.bn_3 = nn.BatchNorm2d(output_channels)
49
+ self.diff = True
50
+ self.relu = nn.ReLU()
51
+
52
+ def forward(self, x):
53
+ # convolution
54
+ out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
55
+
56
+ # residual
57
+ if self.diff:
58
+ x = self.bn_3(self.conv_3(x))
59
+ out = x + out
60
+ out = self.relu(out)
61
+ return out
62
+
63
+
64
+ class CNNSA(nn.Module):
65
+ '''
66
+ Won et al. 2019
67
+ Toward interpretable music tagging with self-attention.
68
+ Feature extraction with CNN + temporal summary with Transformer encoder.
69
+ '''
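+ # Rough data flow for a 15 s / 16 kHz mono clip: mel spectrogram (batch, 1, 128, time) ->
+ # seven Res_2d blocks collapse the 128 mel bins to 1 and yield 2*n_channels = 256 feature maps ->
+ # squeeze + permute to (batch, time', 256) -> a fixed random [CLS] vector is prepended ->
+ # 2-layer BertEncoder; forward() returns the hidden states of the last encoder layer.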
70
+ def __init__(self,
71
+ n_channels=128,
72
+ sample_rate=16000,
73
+ n_fft=512,
74
+ f_min=0.0,
75
+ f_max=8000.0,
76
+ n_mels=128,
77
+ n_class=50):
78
+ super(CNNSA, self).__init__()
79
+
80
+ # Spectrogram
81
+ self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
82
+ n_fft=n_fft,
83
+ f_min=f_min,
84
+ f_max=f_max,
85
+ n_mels=n_mels)
86
+ self.to_db = torchaudio.transforms.AmplitudeToDB()
87
+ self.spec_bn = nn.BatchNorm2d(1)
88
+
89
+ # CNN
90
+ self.layer1 = Res_2d(1, n_channels, stride=2)
91
+ self.layer2 = Res_2d(n_channels, n_channels, stride=2)
92
+ self.layer3 = Res_2d(n_channels, n_channels*2, stride=2)
93
+ self.layer4 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
94
+ self.layer5 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
95
+ self.layer6 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
96
+ self.layer7 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
97
+
98
+ # Transformer encoder
99
+ bert_config = BertConfig(vocab_size=256,
100
+ hidden_size=256,
101
+ num_hidden_layers=2,
102
+ num_attention_heads=8,
103
+ intermediate_size=1024,
104
+ hidden_act="gelu",
105
+ hidden_dropout_prob=0.4,
106
+ max_position_embeddings=700,
107
+ attention_probs_dropout_prob=0.5)
108
+ self.encoder = BertEncoder(bert_config)
109
+ self.pooler = BertPooler(bert_config)
110
+ self.vec_cls = self.get_cls(256)
111
+
112
+ # Dense
113
+ self.dropout = nn.Dropout(0.5)
114
+ self.dense = nn.Linear(256, n_class)
115
+
116
+ def get_cls(self, channel):
117
+ np.random.seed(0)
118
+ single_cls = torch.Tensor(np.random.random((1, channel)))
119
+ vec_cls = torch.cat([single_cls for _ in range(64)], dim=0)
120
+ vec_cls = vec_cls.unsqueeze(1)
121
+ return vec_cls
122
+
123
+ def append_cls(self, x):
124
+ batch, _, _ = x.size()
125
+ part_vec_cls = self.vec_cls[:batch].clone()
126
+ part_vec_cls = part_vec_cls.to(x.device)
127
+ return torch.cat([part_vec_cls, x], dim=1)
128
+
129
+ def get_spec(self, ids, audio_length=15*16000, allow_random=False):
130
+
131
+ wav_list = list()
132
+
133
+ for id in ids:
134
+ audio_path = os.path.join("/import/c4dm-datasets/Music4All/music4all/audios", id + '.mp3')
135
+ (wav, sample_rate) = torchaudio.backend.sox_io_backend.load(audio_path)
136
+
137
+ # to mono
138
+ mono_wav = torch.mean(wav, dim=0)
139
+
140
+ # cut length
141
+ if allow_random:
142
+ random_index = random.randint(0, len(mono_wav) - audio_length - 1)
143
+ else:
144
+ random_index = 0
145
+ mono_wav_cut = mono_wav[random_index: random_index + audio_length]
146
+
147
+ wav_list.append(mono_wav_cut)
148
+
149
+ # merge wav to (bs, length)
150
+ data = torch.stack(wav_list, dim=0)
151
+
152
+ # to spectrogram
153
+ spectrogram = self.spec(data.cuda())
154
+
155
+ return spectrogram
156
+
157
+ def forward(self, ids):
158
+ # Spectrogram
159
+ # for batch
160
+ spec = self.get_spec(ids)
161
+ spec_db = self.to_db(spec)
162
+ x = spec_db.unsqueeze(1) # add channel dim
163
+ x = self.spec_bn(x)
164
+
165
+ # CNN
166
+ x = self.layer1(x)
167
+ x = self.layer2(x)
168
+ x = self.layer3(x)
169
+ x = self.layer4(x)
170
+ x = self.layer5(x)
171
+ x = self.layer6(x)
172
+ x = self.layer7(x)
173
+ x = x.squeeze(2)
174
+
175
+ # Get [CLS] token
176
+ x = x.permute(0, 2, 1)
177
+ x = self.append_cls(x)
178
+
179
+ # Transformer encoder
180
+ x = self.encoder(x)
181
+ x = x[-1] # last layer
182
+ # x = self.pooler(x)
183
+ #
184
+ # # Dense
185
+ # x = self.dropout(x)
186
+ # x = self.dense(x)
187
+ # x = nn.Sigmoid()(x)
188
+
189
+ return x  # hidden states of the last encoder layer, shape (batch, 1 + time_frames, 256); index 0 is the [CLS] position
190
+
191
+
192
+ # test code
193
+ # model = CNNSA()
194
+ # model.load_state_dict(torch.load("best_model.pth"))
195
+ # id = ["wlIcjSZkgW0cgWrm", "wlIcjSZkgW0cgWrm"]
196
+ # output = model(id)
code/train.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ from data import LyricsCommentsDatasetPsuedo
4
+ from torch import utils, nn
5
+ from model import CommentGenerator
6
+ import transformers
7
+ import time
8
+ import statistics
9
+ import os
10
+ import random
11
+ import datasets
12
+
13
+ IS_LOAD = False
14
+ LOAD_EPOCH = 0
15
+ EPOCH = 20
16
+ BATCH_SIZE = 8
17
+ LOG_INTERVAL = 100
18
+ SAMPLE_INTERVAL = 2000
19
+ VALIDATION_INTERVAL = 2
20
+ LOG_FOLDER = "log/"
21
+ MODEL_FOLDER = "model/"
22
+ EARLY_STOPPING_INTERVAL = 5
23
+ MODEL_NAME = "bart_baseline_full_256"
24
+ CHOICE_NUMBER = 5
25
+ DATASET_PATH = "dataset_not_negative_256.pkl"
26
+
27
+ os.environ["CUDA_VISIBLE_DEVICES"] = "4"
28
+
29
+ dataset = LyricsCommentsDatasetPsuedo(dataset_path=DATASET_PATH)
30
+ dataset_length = len(dataset)
31
+
32
+ train_dataset_length = int(dataset_length * 0.9)
33
+ valid_dataset_length = dataset_length - train_dataset_length
34
+ train_dataset, valid_dataset = utils.data.random_split(dataset,
35
+ [train_dataset_length,
36
+ valid_dataset_length],
37
+ generator=torch.Generator().manual_seed(42))
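+ # The fixed generator seed keeps the 90/10 train/validation split reproducible across runs
+ # (the same seed and ratio are used in train_fusion.py), so validation scores stay comparable.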
38
+ train_dataloader = utils.data.DataLoader(train_dataset,
39
+ batch_size=BATCH_SIZE,
40
+ shuffle=True)
41
+ valid_dataloader = utils.data.DataLoader(valid_dataset,
42
+ batch_size=32,
43
+ shuffle=False)
44
+
45
+ model = CommentGenerator().cuda()
46
+
47
+ criterion = nn.CrossEntropyLoss()
48
+
49
+ optimizer = transformers.Adafactor(model.parameters(), warmup_init=False, relative_step=False,
50
+ lr=6e-4,
51
+ )
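+ # With relative_step=False (and warmup_init=False), Adafactor skips its time-dependent
+ # relative-step schedule and uses the externally supplied learning rate (6e-4) instead.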
52
+
53
+ loss_stat = list()
54
+ start_time = time.time()
55
+ start_time_local = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
56
+
57
+ early_stop_token = (0.0, 0)
58
+
59
+ model.train()
60
+ for epoch in range(1 + LOAD_EPOCH, EPOCH + 1 + LOAD_EPOCH):
61
+ for batch_index, [lyrics, comment] in enumerate(train_dataloader):
62
+ # pre-process data
63
+ input_sentences = lyrics
64
+ raw_labels = comment
65
+ output = model(input_sentences, raw_labels)
66
+ loss = output.loss
67
+
68
+ optimizer.zero_grad()
69
+ loss.backward()
70
+ optimizer.step()
71
+ loss_stat.append(loss.item())
72
+
73
+ # log
74
+ if batch_index and batch_index % LOG_INTERVAL == 0:
75
+ curr_time = time.time()
76
+ passed_time_all = curr_time - start_time
77
+ time_str = f"{int(passed_time_all / 60)}:{int(passed_time_all % 60)}"
78
+ log = f"{MODEL_NAME}\t" \
79
+ f"Time: {time_str}\t" \
80
+ f"Epoch {epoch}: {batch_index}/{int(len(train_dataloader.dataset) / BATCH_SIZE)}\t" \
81
+ f"Loss: {statistics.mean(loss_stat[-1 * BATCH_SIZE:])}\t" \
82
+ f"Avg loss: {statistics.mean(loss_stat)}"
83
+ if __debug__:
84
+ print(log)
85
+ with open(os.path.join(LOG_FOLDER, MODEL_NAME + "_" + start_time_local + ".txt"), 'a+', encoding='utf-8') as r:
86
+ r.write(log)
87
+ r.write("\n")
88
+ loss_stat = list()
89
+
90
+ if batch_index and batch_index % SAMPLE_INTERVAL == 0:
91
+
92
+ model.eval()
93
+ samples_list = random.choices(valid_dataset, k=CHOICE_NUMBER)
94
+ sample_sentence, sample_label = zip(*samples_list)
95
+ output_samples = model.generate(sample_sentence)
96
+ for sample_index in range(CHOICE_NUMBER):
97
+ log = f"Lyrics: {sample_sentence[sample_index]}\n" \
98
+ f"Sample outputs: {output_samples[sample_index]}\n" \
99
+ f"Ground Truth: {sample_label[sample_index]}"
100
+ if __debug__:
101
+ print(log)
102
+ with open(os.path.join(LOG_FOLDER, MODEL_NAME + "_" + start_time_local + ".txt"), 'a+', encoding='utf-8') as r:
103
+ r.write(log)
104
+ r.write("\n")
105
+ model.train()
106
+
107
+ if epoch and epoch % VALIDATION_INTERVAL == 0:
108
+ model.eval()
109
+ metrics = datasets.load_metric('rouge')
110
+ valid_dataloader = utils.data.DataLoader(valid_dataset,
111
+ batch_size=32,
112
+ shuffle=False)
113
+ for batch_index_valid, [lyrics_valid, comment_valid] in enumerate(valid_dataloader):
114
+ output_samples = model.generate(lyrics_valid)
115
+ metrics.add_batch(predictions=output_samples, references=comment_valid)
116
+
117
+ # control time.
118
+ if batch_index_valid > 10:
119
+ break
120
+ score = metrics.compute()
121
+ if __debug__:
122
+ print(str(score))
123
+ with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+',
124
+ encoding='utf-8') as r:
125
+ r.write(str(score))
126
+ r.write("\n")
127
+
128
+ # save
129
+ if score['rouge1'].mid.recall > early_stop_token[0]:
130
+ early_stop_token = [score['rouge1'].mid.recall, epoch]  # record the new best score and its epoch
131
+ torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_best.pt"))
132
+ torch.save(optimizer.state_dict(),
133
+ os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_best.pt"))
134
+
135
+ if epoch:
136
+ torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_epoch{epoch}.pt"))
137
+ torch.save(optimizer.state_dict(),
138
+ os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_epoch{epoch}.pt"))
139
+
140
+ # early stopping
141
+ if score['rouge1'].mid.recall <= early_stop_token[0] and epoch > (
142
+ early_stop_token[1] + EARLY_STOPPING_INTERVAL):
143
+ print(f"Early Stopping. Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")
144
+
145
+ model.train()
code/train_fusion.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ from data import LyricsCommentsDatasetPsuedo_fusion
4
+ from torch import utils, nn
5
+ from model_fusion import CommentGenerator_fusion
6
+ import transformers
7
+ import time
8
+ import statistics
9
+ import os
10
+ import random
11
+ import datasets
12
+
13
+ IS_LOAD = False
14
+ LOAD_EPOCH = 0
15
+ EPOCH = 50
16
+ BATCH_SIZE = 8
17
+ LOG_INTERVAL = 100
18
+ SAMPLE_INTERVAL = 1000
19
+ VALIDATION_INTERVAL = 2
20
+ LOG_FOLDER = "log/"
21
+ MODEL_FOLDER = "model/"
22
+ SAVE_INTERVAL = 2
23
+ EARLY_STOPPING_INTERVAL = 5
24
+ MODEL_NAME = "bart_fusion_full_256"
25
+ CHOICE_NUMBER = 2
26
+ DATASET_PATH = "/homes/yz007/multimodal-transformer/comment_generator/dataset_full_256.pkl"
27
+
28
+ os.environ["CUDA_VISIBLE_DEVICES"] = "2"
29
+
30
+ dataset = LyricsCommentsDatasetPsuedo_fusion(dataset_path=DATASET_PATH)
31
+ dataset_length = len(dataset)
32
+
33
+ train_dataset_length = int(dataset_length * 0.9)
34
+ valid_dataset_length = dataset_length - train_dataset_length
35
+ train_dataset, valid_dataset = utils.data.random_split(dataset,
36
+ [train_dataset_length,
37
+ valid_dataset_length],
38
+ generator=torch.Generator().manual_seed(42))
39
+ train_dataloader = utils.data.DataLoader(train_dataset,
40
+ batch_size=BATCH_SIZE,
41
+ shuffle=True)
42
+ # valid_dataloader = utils.data.DataLoader(valid_dataset,
43
+ # batch_size=32,
44
+ # shuffle=False)
45
+
46
+ model = CommentGenerator_fusion().cuda()
47
+
48
+ criterion = nn.CrossEntropyLoss()
49
+
50
+
51
+
52
+ # optimizer = transformers.Adafactor(filter(lambda p: p.requires_grad, model.parameters()),
53
+ # lr=6e-4,
54
+ # )
55
+ optimizer = transformers.Adafactor(model.parameters(), warmup_init=False, relative_step=False,
56
+ lr=6e-4,
57
+ )
58
+
59
+ if IS_LOAD:
60
+ model.load_state_dict(torch.load("/homes/yz007/multimodal-transformer/comment_generator/model/bart_fusion_positive_256_6e-4_epoch6.pt"))
61
+ optimizer.load_state_dict(torch.load("/homes/yz007/multimodal-transformer/comment_generator/model/bart_fusion_positive_256_6e-4_optim_epoch6.pt"))
62
+
63
+ loss_stat = list()
64
+ start_time = time.time()
65
+ start_time_local = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
66
+
67
+ early_stop_token = [0.0, 0]
68
+ validation_loss_history = list()
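+ # Two signals are tracked for stopping/checkpointing: early_stop_token holds the best ROUGE-1
+ # recall seen so far and its epoch (used to save the *_best checkpoints), while
+ # validation_loss_history is later checked for a plateau to trigger early stopping.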
69
+
70
+ model.train()
71
+ for epoch in range(1 + LOAD_EPOCH, EPOCH + 1 + LOAD_EPOCH):
72
+ for batch_index, [lyrics, comment, music_id] in enumerate(train_dataloader):
73
+ # pre-process data
74
+ input_sentences = lyrics
75
+ raw_labels = comment
76
+ output = model(input_sentences, music_id, raw_labels)
77
+ loss = output.loss
78
+
79
+ optimizer.zero_grad()
80
+ loss.backward()
81
+ optimizer.step()
82
+ loss_stat.append(loss.item())
83
+
84
+ # log
85
+ if batch_index and batch_index % LOG_INTERVAL == 0:
86
+ curr_time = time.time()
87
+ passed_time_all = curr_time - start_time
88
+ time_str = f"{int(passed_time_all / 60)}:{int(passed_time_all % 60)}"
89
+ log = f"{MODEL_NAME}\t" \
90
+ f"Time: {time_str}\t" \
91
+ f"Epoch {epoch}: {batch_index}/{int(len(train_dataloader.dataset) / BATCH_SIZE)}\t" \
92
+ f"Loss: {statistics.mean(loss_stat[-1 * LOG_INTERVAL * BATCH_SIZE:])}\t" \
93
+ f"Avg loss: {statistics.mean(loss_stat)}"
94
+ if __debug__:
95
+ print(log)
96
+ with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+',
97
+ encoding='utf-8') as r:
98
+ r.write(log)
99
+ r.write("\n")
100
+ loss_stat = list()
101
+
102
+ if batch_index and batch_index % SAMPLE_INTERVAL == 0:
103
+ # make samples
104
+ model.eval()
105
+ samples_list = random.choices(valid_dataset, k=CHOICE_NUMBER)
106
+ sample_sentence, sample_label, music_ids = zip(*samples_list)
107
+ with torch.no_grad():
108
+ output_samples = model.generate(sample_sentence, music_ids)
109
+ for sample_index in range(CHOICE_NUMBER):
110
+ log = f"Lyrics: {sample_sentence[sample_index]}\n" \
111
+ f"Sample outputs: {output_samples[sample_index]}\n" \
112
+ f"Ground Truth: {sample_label[sample_index]}"
113
+ if __debug__:
114
+ print(log)
115
+ with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+',
116
+ encoding='utf-8') as r:
117
+ r.write(log)
118
+ r.write("\n")
119
+
120
+ # validation loss
121
+ valid_dataloader = utils.data.DataLoader(valid_dataset,
122
+ batch_size=8,
123
+ shuffle=False)
124
+ valid_loss_stat = list()
125
+ for batch_index_valid, [lyrics_valid, comment_valid, music_id_valid] in enumerate(valid_dataloader):
126
+ with torch.no_grad():
127
+ output_valid = model(lyrics_valid, music_id_valid, comment_valid)
128
+ valid_loss = output_valid.loss.item()
129
+ valid_loss_stat.append(valid_loss)
130
+ if batch_index_valid > 15:
131
+ break
132
+ valid_loss_mean = statistics.mean(valid_loss_stat)
133
+ validation_loss_history.append(valid_loss_mean)
134
+ log = f"{MODEL_NAME}\t" \
135
+ f"Time: {time_str}\t" \
136
+ f"Epoch {epoch}: {batch_index}/{int(len(train_dataloader.dataset) / BATCH_SIZE)}\t" \
137
+ f"Validation Loss: {valid_loss_mean}\t"
138
+ if __debug__:
139
+ print(log)
140
+ with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+',
141
+ encoding='utf-8') as r:
142
+ r.write(log)
143
+ r.write("\n")
144
+
145
+ # back to train
146
+ model.train()
147
+
148
+ if epoch and epoch % VALIDATION_INTERVAL == 0:
149
+ model.eval()
150
+ metrics = datasets.load_metric('rouge')
151
+ valid_dataloader = utils.data.DataLoader(valid_dataset,
152
+ batch_size=8,
153
+ shuffle=False)
154
+ for batch_index_valid, [lyrics_valid, comment_valid, music_id_valid] in enumerate(valid_dataloader):
155
+ with torch.no_grad():
156
+ output_samples = model.generate(lyrics_valid, music_id_valid)
157
+ metrics.add_batch(predictions=output_samples, references=comment_valid)
158
+ # control time.
159
+ if batch_index_valid > 10:
160
+ break
161
+ score = metrics.compute()
162
+ if __debug__:
163
+ print(str(score))
164
+ with open(os.path.join(LOG_FOLDER, MODEL_NAME + '_' + start_time_local + ".txt"), 'a+',
165
+ encoding='utf-8') as r:
166
+ r.write(str(score))
167
+ r.write("\n")
168
+
169
+ # save
170
+ if score['rouge1'].mid.recall > early_stop_token[0]:
171
+ early_stop_token = [score['rouge1'].mid.recall, epoch]  # record the new best score and its epoch
172
+ torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_best.pt"))
173
+ torch.save(optimizer.state_dict(),
174
+ os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_best.pt"))
175
+
176
+ # save
177
+ if epoch and epoch % SAVE_INTERVAL == 0:
178
+ torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_epoch{epoch}.pt"))
179
+ torch.save(optimizer.state_dict(),
180
+ os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_optim_epoch{epoch}.pt"))
181
+
182
+ # early stopping
183
+ if len(validation_loss_history) > EARLY_STOPPING_INTERVAL:
184
+ if min(validation_loss_history[-2 * EARLY_STOPPING_INTERVAL:]) == validation_loss_history[-2 * EARLY_STOPPING_INTERVAL]:
185
+ print(f"Early Stopping. Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")
186
+ break
187
+ if score['rouge1'].mid.recall <= early_stop_token[0] and epoch > (
188
+ early_stop_token[1] + EARLY_STOPPING_INTERVAL):
189
+ print(f"Early Stopping. Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")
190
+ break
191
+ model.train()
192
+
193
+ print(f"Training Complete. Best Score: {early_stop_token[0]} at Epoch {early_stop_token[1]}.")