fffiloni committed
Commit 5898616 · verified · 1 Parent(s): 6935ada

Create model_zero.py

Files changed (1)
  1. model_zero.py +251 -0
model_zero.py ADDED
@@ -0,0 +1,251 @@
+ # Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import torch
+ import soundfile as sf
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from peft import LoraConfig, TaskType, get_peft_model
+ from transformers import (
+     WhisperFeatureExtractor,
+     WhisperModel,
+     LlamaForCausalLM,
+     LlamaTokenizer
+ )
+ import librosa
+ from beats.BEATs import BEATsConfig, BEATs
+ from qformer.Qformer import BertConfig, BertLMHeadModel
+
+ class SALMONN(nn.Module):
+     def __init__(
+         self,
+         ckpt,
+         whisper_path,
+         beats_path,
+         vicuna_path,
+         speech_qformer_token_num=1,
+         speech_qformer_layer=2,
+         lora=True,
+         lora_alpha=32,
+         lora_rank=8,
+         lora_dropout=0.1,
+         second_per_frame=0.333333,
+         second_stride=0.333333,
+         low_resource=False
+     ):
+
+         super().__init__()
+
+         # feature_extractor
+         self.feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_path)
+
+         # whisper
+         self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder
+         self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model)
+
+         # beats
+         self.beats_ckpt = beats_path
+         beats_checkpoint = torch.load(self.beats_ckpt, map_location='cpu')
+         beats_cfg = BEATsConfig(beats_checkpoint['cfg'])
+         beats = BEATs(beats_cfg)
+         beats.load_state_dict(beats_checkpoint['model'])
+         self.beats = beats
+         self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
+         for name, param in self.beats.named_parameters():
+             param.requires_grad = False
+         self.beats.eval()
+
+         # init speech Qformer
+         self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
+             speech_qformer_token_num,
+             self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim,
+             speech_qformer_layer,
+         )
+         self.second_per_frame = second_per_frame
+         self.second_stride = second_stride
+
+         # vicuna
+         if not low_resource:
+             self.llama_model = LlamaForCausalLM.from_pretrained(
+                 vicuna_path,
+                 torch_dtype=torch.float16,
+             )
+         else:
+             self.llama_model = LlamaForCausalLM.from_pretrained(
+                 vicuna_path,
+                 torch_dtype=torch.float16,
+                 load_in_8bit=True,
+                 device_map={'': 0}
+             )
+
+         # lora
+         self.lora = lora
+         if lora:
+             target_modules = None
+             self.peft_config = LoraConfig(
+                 task_type=TaskType.CAUSAL_LM,
+                 inference_mode=True,
+                 r=lora_rank,
+                 lora_alpha=lora_alpha,
+                 lora_dropout=lora_dropout,
+                 target_modules=target_modules,
+             )
+             self.llama_model = get_peft_model(self.llama_model, self.peft_config)
+
+         # tokenizer
+         self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_path, use_fast=False)
+         self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+         self.llama_tokenizer.padding_side = "right"
+
+         # proj
+         self.speech_llama_proj = nn.Linear(
+             self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size)
+
+         # load ckpt (weights are mapped to CPU first; move the module to the target device afterwards)
+         ckpt_dict = torch.load(ckpt, map_location='cpu')['model']
+         self.load_state_dict(ckpt_dict, strict=False)
+
+     def generate(
+         self,
+         wav_path,
+         prompt,
+         prompt_pattern="USER: <Speech><SpeechHere></Speech> {}\nASSISTANT:",
+         device='cuda:0',
+         max_length=150,
+         num_beams=4,
+         do_sample=True,
+         min_length=1,
+         top_p=0.9,
+         repetition_penalty=1.0,
+         length_penalty=1.0,
+         temperature=1.0,
+     ):
+         # read wav
+         wav, sr = sf.read(wav_path)
+         if len(wav.shape) == 2:
+             wav = wav[:, 0]
+         if len(wav) > 30 * sr:
+             wav = wav[: 30 * sr]
+         if sr != 16000:
+             wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")
+
+         # whisper
+         spectrogram = self.feature_extractor(wav, return_tensors="pt", sampling_rate=16000).input_features.to(device) # [1, 80, 3000]
+         speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state
+
+         # beats
+         raw_wav = torch.from_numpy(wav).to(device).unsqueeze(0)
+         audio_padding_mask = torch.zeros(raw_wav.shape, device=device).bool()
+         audio_embeds, _ = self.beats.extract_features(raw_wav, padding_mask=audio_padding_mask, feature_only=True)
+
+         # auditory embeds
+         speech_embeds = self.ln_speech(speech_embeds)
+         audio_embeds = self.ln_audio(audio_embeds)
+         audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
+         speech_embeds = torch.cat([speech_embeds, audio_embeds], dim=-1)
+
+         # split frames
+         B, T, C = speech_embeds.shape
+         kernel = round(T * self.second_per_frame / 30.0)
+         stride = round(T * self.second_stride / 30.0)
+         kernel = (1, kernel)
+         stride = (1, stride)
+         speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
+         speech_embeds_overlap = F.unfold(speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride)
+         _, _, L = speech_embeds_overlap.shape
+         speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
+         speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
+         speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
+         speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device)
+
+         # Qformer
+         query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
+         query_output = self.speech_Qformer.bert(
+             query_embeds=query_tokens,
+             encoder_hidden_states=speech_embeds,
+             encoder_attention_mask=speech_atts,
+             return_dict=True,
+         )
+         speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)
+         speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()
+         speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device)
+
+         # USER: <Speech>speech_embeds</Speech> prompt\nASSISTANT:
+         embed_tokens = self.llama_model.model.model.embed_tokens if self.lora else self.llama_model.model.embed_tokens
+         prompt_left, prompts_right = prompt_pattern.format(prompt).split('<SpeechHere>')
+         prompt_left_ids = self.llama_tokenizer(
+             prompt_left,
+             return_tensors="pt",
+             add_special_tokens=False
+         ).to(speech_embeds.device).input_ids
+         prompt_left_embeds = embed_tokens(prompt_left_ids)
+         prompt_right_ids = self.llama_tokenizer(
+             prompts_right,
+             return_tensors="pt",
+             add_special_tokens=False
+         ).to(speech_embeds.device).input_ids
+         prompt_right_embeds = embed_tokens(prompt_right_ids)
+
+         bos_embeds = self.llama_model.model.embed_tokens(
+             torch.ones(
+                 [1, 1],
+                 dtype=torch.long,
+                 device=device,
+             ) * self.llama_tokenizer.bos_token_id
+         ) if not self.lora else self.llama_model.model.model.embed_tokens(
+             torch.ones(
+                 [1, 1],
+                 dtype=torch.long,
+                 device=device,
+             ) * self.llama_tokenizer.bos_token_id
+         )
+
+         embeds = torch.cat([bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds], dim=1)
+         atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
+
+         # generate
+         output = self.llama_model.generate(
+             inputs_embeds=embeds,
+             max_length=max_length,
+             num_beams=num_beams,
+             do_sample=do_sample,
+             min_length=min_length,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             length_penalty=length_penalty,
+             temperature=temperature,
+             attention_mask=atts,
+             bos_token_id=self.llama_tokenizer.bos_token_id,
+             eos_token_id=self.llama_tokenizer.eos_token_id,
+             pad_token_id=self.llama_tokenizer.pad_token_id
+         )
+
+         output_text = self.llama_tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)
+
+         return output_text
+
+     def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2):
+         encoder_config = BertConfig()
+         encoder_config.num_hidden_layers = num_hidden_layers
+         encoder_config.encoder_width = speech_width
+         encoder_config.add_cross_attention = True
+         encoder_config.cross_attention_freq = 1
+         encoder_config.query_length = num_query_token
+         Qformer = BertLMHeadModel(config=encoder_config)
+         query_tokens = nn.Parameter(
+             torch.zeros(1, num_query_token, encoder_config.hidden_size)
+         )
+         query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+         return Qformer, query_tokens
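
A minimal usage sketch for the class added above, not taken from this commit: every path below is a placeholder for locally downloaded SALMONN, Whisper, BEATs, and Vicuna weights, and inference is wrapped in float16 autocast so the half-precision LLaMA weights and the float32 audio front-end can be mixed on one GPU.

# Usage sketch -- all paths are placeholders, single-GPU inference assumed.
import torch
from model_zero import SALMONN

model = SALMONN(
    ckpt="path/to/salmonn_checkpoint.pth",     # placeholder: SALMONN checkpoint
    whisper_path="path/to/whisper-large-v2",   # placeholder: Whisper weights
    beats_path="path/to/beats_checkpoint.pt",  # placeholder: BEATs checkpoint
    vicuna_path="path/to/vicuna-7b",           # placeholder: Vicuna weights
)
model.to("cuda:0")  # matches the default device used inside generate()
model.eval()

# Autocast keeps the fp16 LLaMA and the fp32 audio encoders dtype-compatible.
with torch.cuda.amp.autocast(dtype=torch.float16):
    answer = model.generate("path/to/audio.wav", prompt="Describe the audio clip.")
print(answer[0])  # generate() returns a list of decoded strings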