maxoul committed
Commit ac4eef4 · verified · 1 Parent(s): 9e5c9e5

Upload PISCO

Files changed (2)
  1. config.json +1 -1
  2. modelling_pisco.py +78 -59
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-    "_name_or_path": "/scratch/1/user/mlouis/calmar/pisco_hub_models/mistral_with_mistral_labels",
+    "_name_or_path": "/scratch/1/user/mlouis/calmar/pisco_hub_models/pisco-mistral",
     "architectures": [
         "PISCO"
     ],
modelling_pisco.py CHANGED
@@ -108,12 +108,29 @@ class PISCO(PreTrainedModel):
     def compress(self, enc_input_ids, enc_attention_mask):
         return self.compr_decoder(enc_input_ids, enc_attention_mask)
 
-    def replace_emb(self, enc_input_ids, compressed_embs, dec_input_ids):
+    def replace_emb(self, compressed_embs, dec_input_ids):
         """
-        Compression logic (either with decoder or with dedicated compressor)
+        Create an input embedding vector combining the compressed_embs and the dec_input_ids
         """
-        indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
-        input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
+        indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
+
+        input_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
+        num_embs = compressed_embs.size(1)
+        if self.sep:
+            slot_len = num_embs + 1
+        else:
+            slot_len = num_embs
+        # get first mem_token indices
+        first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
+        batch_size = input_embeds.size(0)
+        # for each example in batch, replace them with compressed embeddings
+        for i in range(batch_size):
+            for j in range(indices[i], indices[i + 1]):
+                start_idx = first_mem_token_indices[i].item() + (j - indices[i]) * slot_len
+                assert input_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
+                    f"{input_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
+                input_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
+
         return input_embeds
 
     def compr_decoder(self, input_ids, attention_mask):
@@ -126,24 +143,16 @@ class PISCO(PreTrainedModel):
         if 'encoder_adapter' in self.adapter_keys:
             self.decoder.set_adapter('encoder_adapter')
 
-        print(self.decoder.device, input_ids.device, attention_mask.device)
-
         emb = self.decoder(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=True).hidden_states[-1]
         mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
         return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
 
-    def prepare_encoder_inputs_to_decoder(self, texts, max_length, q_texts=None):
-        if q_texts is not None:
-            texts_to_encode = [self.tokenizer.enc_token + self.tokenizer.bos_token + '\nQuery:\n' + query + 'Document:\n' + text + self.tokenizer.eos_token
-                               for text, query in zip(texts, q_texts)]
-            inp_enc = self.tokenizer(texts_to_encode, return_tensors='pt', padding='max_length', max_length=max_length + 8, truncation=True, add_special_tokens=False)
-        else:
-            inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
-            inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length+3, truncation=True, add_special_tokens=False)
-
-        num_mem_tokens = 128 // self.compr_rate # maybe change that
+    def prepare_encoder_inputs_to_decoder(self, texts, max_length):
+        inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
+        inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length+3, truncation=True, add_special_tokens=False)
+        num_mem_tokens = 128 // self.compr_rate # hardcode size
         assert num_mem_tokens == len(self.tokenizer.mem_tokens)
         inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
                                                                                        inp_enc['attention_mask'],
@@ -155,28 +164,6 @@ class PISCO(PreTrainedModel):
     def prepare_encoder_inputs(self, texts, max_length):
         return self.prepare_encoder_inputs_to_decoder(texts, max_length)
 
-    def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
-        """
-        Replace memory tokens in the decoder input to with the compressed embeddings
-        """
-        inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
-        num_embs = compressed_embs.size(1)
-        if self.sep:
-            slot_len = num_embs + 1
-        else:
-            slot_len = num_embs
-        # get first mem_token indices
-        first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
-        batch_size = inputs_embeds.size(0)
-        # for each example in batch, replace them with compressed embeddings
-        for i in range(batch_size):
-            for j in range(indices[i], indices[i + 1]):
-                start_idx = first_mem_token_indices[i].item() + (j - indices[i]) * slot_len
-                assert inputs_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
-                    f"{inputs_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
-                inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
-        return inputs_embeds
-
     def forward(self,
                 enc_input_ids: torch.LongTensor = None,
                 enc_attention_mask: torch.LongTensor = None,
@@ -204,7 +191,7 @@ class PISCO(PreTrainedModel):
 
         # Perform compression with gradient tracking
         compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
-        inputs_embeds = self.replace_emb(enc_input_ids, compressed_embs, dec_input_ids)
+        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
 
         # decoding
         if 'decoder_adapter' in self.adapter_keys:
@@ -218,42 +205,80 @@ class PISCO(PreTrainedModel):
         return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
 
     def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
-        # TODO: test
+        """
+        Generates answers from documents (via compression then decoding)
+        questions: list of string
+        documents: list of list of strings (they should all be of equal length: the nb of doc for each question)
+        """
         self.generation_top_k = len(documents[0])
         assert len(documents) == len(questions)
        assert all([len(context) == len(documents[0]) for context in documents])
         flat_documents = sum(documents, [])
 
         model_input = {}
+
+        # Creating encoder inputs:
         input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
         device = self.decoder.device
-
         model_input['enc_input_ids'], model_input['enc_attention_mask'] = input_encoder['input_ids'].to(device), input_encoder['attention_mask'].to(device)
 
+        # Creating decoder inputs
         instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
-
         inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
-
         model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
 
+        # Generation
         return self.generate(model_input, max_new_tokens=max_new_tokens)
+
+    def generate_from_compressed_documents_and_questions(self, questions: list[str], compressed_documents: torch.Tensor, max_new_tokens: int = 128) -> list[str]:
+        """
+        Generates answers from compressed documents
+        questions: list of string
+        compressed_documents: torch tensor, its first dimension should be a multiple of len(questions)
+        """
+        print(compressed_documents.size(), len(questions))
+        self.generation_top_k = compressed_documents.size(0) // len(questions)
+        assert compressed_documents.size(0) % self.generation_top_k == 0, f"{compressed_documents.size(0)} {self.generation_top_k}"
+
+        # Creating decoder inputs
+        instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
+        inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
+        device = self.decoder.device
+        dec_input_ids, dec_attention_mask = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
+
+        # Creating input decoder embeddings from prompt + compressed documents
+        inputs_embeds = self.replace_emb(compressed_documents, dec_input_ids)
+
+        # Activating decoder generator:
+        if 'decoder_adapter' in self.adapter_keys:
+            self.decoder.set_adapter('decoder_adapter')
+
+        output_ids = self.decoder.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=dec_attention_mask,
+            generation_config=self.generation_config,
+            max_new_tokens=max_new_tokens
+        )
+
+        # de-tokenizing
+        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
     def compress_documents(self, documents: list[str]) -> torch.Tensor:
-        # TODO: test
+        """
+        Compress a list of documents
+        """
         input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
         enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
         attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
-        print('yo', self.decoder.device, enc_input_ids.device, attention_mask.device)
         return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
 
     def generate(self, model_input, max_new_tokens=128):
+        """
+        Generation pipeline including compression + decoding from compressed
+        """
 
         enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
 
-        print('in gen')
-        print(enc_input_ids.size())
-        print(dec_input_ids.size())
-
         assert enc_input_ids.size() == enc_attention_mask.size()
 
         if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
@@ -266,13 +291,11 @@ class PISCO(PreTrainedModel):
             f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
 
         compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
-        inputs_embeds = self.replace_emb(enc_input_ids, compressed_embs, dec_input_ids)
+        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
 
-        # Switch adapter if we are training two different ones:
         if 'decoder_adapter' in self.adapter_keys:
-            self.decoder.set_adapter('decoder_adapter')
+            self.decoder.set_adapter('decoder_adapter')
 
-
         output_ids = self.decoder.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=dec_attention_mask,
@@ -280,18 +303,14 @@ class PISCO(PreTrainedModel):
             max_new_tokens=max_new_tokens
         )
 
-        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
-        return decoded
-
     def blend_prompt_and_memory_tokens(self, query: str):
         """
         Takes care of blending the prompt with the memory tokens:
         Also returns, if a label is provided, the position of the first token index of the label (for loss comp later on)
         """
-
-        mem_tokens_str = ''.join(self.tokenizer.mem_tokens)
-        mem_tokens_str += self.tokenizer.sep_token
+        mem_tokens_str = ''.join(self.tokenizer.mem_tokens) + self.tokenizer.sep_token
 
         # proper names for "eval" call, don't remove these lines
         docs = mem_tokens_str * self.generation_top_k
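
For context, a minimal usage sketch of the entry points touched by this commit: generate_from_text, compress_documents, and the new generate_from_compressed_documents_and_questions. The repo id is a placeholder and the AutoModel / trust_remote_code loading path is an assumption for illustration, not part of the diff.

from transformers import AutoModel

# Placeholder repo id; loading via AutoModel with trust_remote_code assumes the
# repository registers this custom PISCO class as its auto model class.
model = AutoModel.from_pretrained("<pisco-repo-id>", trust_remote_code=True).eval()

questions = ["Who wrote Les Miserables?"]
# One list of documents per question; all lists must have the same length.
documents = [[
    "Victor Hugo published the novel Les Miserables in 1862.",
    "Les Miserables is a French historical novel set in the early 19th century.",
]]

# End-to-end path: compress the documents, then decode an answer.
answers = model.generate_from_text(questions, documents, max_new_tokens=64)

# Two-step path: compress once, reuse the compressed embeddings later.
flat_docs = [doc for docs in documents for doc in docs]
compressed = model.compress_documents(flat_docs)  # (n_docs, n_mem_tokens, hidden_size)
answers_again = model.generate_from_compressed_documents_and_questions(
    questions, compressed, max_new_tokens=64)

print(answers, answers_again)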