Upload PISCO
Files changed:
- config.json (+1 -1)
- modelling_pisco.py (+78 -59)
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "/scratch/1/user/mlouis/calmar/pisco_hub_models/
+  "_name_or_path": "/scratch/1/user/mlouis/calmar/pisco_hub_models/pisco-mistral",
   "architectures": [
     "PISCO"
   ],
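The "PISCO" architecture named in config.json is implemented by the bundled modelling_pisco.py, so the checkpoint needs remote-code loading. A minimal loading sketch, assuming the upload lands under a hub id such as naver/pisco-mistral (hypothetical here) and that the config's auto_map wires AutoModel to modelling_pisco.PISCO:

    # Hypothetical repo id and loading flow; adjust to the actual upload location.
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "naver/pisco-mistral",   # assumption: the hub id this commit is uploaded to
        trust_remote_code=True,  # lets transformers import the PISCO class from modelling_pisco.py
    )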
modelling_pisco.py
CHANGED
@@ -108,12 +108,29 @@ class PISCO(PreTrainedModel):
     def compress(self, enc_input_ids, enc_attention_mask):
         return self.compr_decoder(enc_input_ids, enc_attention_mask)
 
-    def replace_emb(self,
+    def replace_emb(self, compressed_embs, dec_input_ids):
         """
-
+        Create an input embedding vector combining the compressed_embs and the dec_input_ids
         """
-        indices = range(0,
-
+        indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
+
+        input_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
+        num_embs = compressed_embs.size(1)
+        if self.sep:
+            slot_len = num_embs + 1
+        else:
+            slot_len = num_embs
+        # get first mem_token indices
+        first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
+        batch_size = input_embeds.size(0)
+        # for each example in batch, replace them with compressed embeddings
+        for i in range(batch_size):
+            for j in range(indices[i], indices[i + 1]):
+                start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
+                assert input_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
+                    f"{input_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
+                input_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
+
         return input_embeds
 
     def compr_decoder(self, input_ids, attention_mask):
@@ -126,24 +143,16 @@ class PISCO(PreTrainedModel):
         if 'encoder_adapter' in self.adapter_keys:
             self.decoder.set_adapter('encoder_adapter')
 
-        print(self.decoder.device, input_ids.device, attention_mask.device)
-
         emb = self.decoder(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=True).hidden_states[-1]
         mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
         return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
 
-    def prepare_encoder_inputs_to_decoder(self, texts, max_length
-
-
-
-            inp_enc = self.tokenizer(texts_to_encode, return_tensors='pt', padding='max_length', max_length=max_length + 8, truncation=True, add_special_tokens=False)
-        else:
-            inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
-            inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length+3, truncation=True, add_special_tokens=False)
-
-        num_mem_tokens = 128 // self.compr_rate # maybe change that
+    def prepare_encoder_inputs_to_decoder(self, texts, max_length):
+        inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
+        inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length+3, truncation=True, add_special_tokens=False)
+        num_mem_tokens = 128 // self.compr_rate # hardcode size
         assert num_mem_tokens == len(self.tokenizer.mem_tokens)
         inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
                                                                                        inp_enc['attention_mask'],
@@ -155,28 +164,6 @@ class PISCO(PreTrainedModel):
     def prepare_encoder_inputs(self, texts, max_length):
         return self.prepare_encoder_inputs_to_decoder(texts, max_length)
 
-    def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
-        """
-        Replace memory tokens in the decoder input to with the compressed embeddings
-        """
-        inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
-        num_embs = compressed_embs.size(1)
-        if self.sep:
-            slot_len = num_embs + 1
-        else:
-            slot_len = num_embs
-        # get first mem_token indices
-        first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
-        batch_size = inputs_embeds.size(0)
-        # for each example in batch, replace them with compressed embeddings
-        for i in range(batch_size):
-            for j in range(indices[i], indices[i + 1]):
-                start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
-                assert inputs_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
-                    f"{inputs_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
-                inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
-        return inputs_embeds
-
     def forward(self,
                 enc_input_ids: torch.LongTensor = None,
                 enc_attention_mask: torch.LongTensor = None,
@@ -204,7 +191,7 @@ class PISCO(PreTrainedModel):
 
         # Perform compression with gradient tracking
         compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
-        inputs_embeds = self.replace_emb(
+        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
 
         # decoding
         if 'decoder_adapter' in self.adapter_keys:
@@ -218,42 +205,80 @@ class PISCO(PreTrainedModel):
         return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
 
     def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
-
+        """
+        Generates answers from documents (via compression then decoding)
+        questions: list of string
+        documents: list of list of strings (they should all be of equal length: the nb of doc for each question)
+        """
         self.generation_top_k = len(documents[0])
         assert len(documents) == len(questions)
         assert all([len(context) == len(documents[0]) for context in documents])
         flat_documents = sum(documents, [])
 
         model_input = {}
+
+        # Creating encoder inputs:
         input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
         device = self.decoder.device
-
         model_input['enc_input_ids'], model_input['enc_attention_mask'] = input_encoder['input_ids'].to(device), input_encoder['attention_mask'].to(device)
 
+        # Creating decoder inputs
         instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
-
         inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
-
         model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
 
+        # Generation
         return self.generate(model_input, max_new_tokens=max_new_tokens)
+
+    def generate_from_compressed_documents_and_questions(self, questions: list[str], compressed_documents: torch.Tensor, max_new_tokens: int = 128) -> list[str]:
+        """
+        Generates answers from compressed documents
+        questions: list of string
+        compressed_documents: torch tensor, its first dimension should be a multiple of len(questions)
+        """
+        print(compressed_documents.size(), len(questions))
+        self.generation_top_k = compressed_documents.size(0) // len(questions)
+        assert compressed_documents.size(0) % self.generation_top_k == 0, f"{compressed_documents.size(0)} {self.generation_top_k}"
+
+        # Creating decoder inputs
+        instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
+        inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
+        device = self.decoder.device
+        dec_input_ids, dec_attention_mask = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
+
+        # Creating input decoder embeddings from prompt + compressed documents
+        inputs_embeds = self.replace_emb(compressed_documents, dec_input_ids)
+
+        # Activating decoder generator:
+        if 'decoder_adapter' in self.adapter_keys:
+            self.decoder.set_adapter('decoder_adapter')
+
+        output_ids = self.decoder.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=dec_attention_mask,
+            generation_config=self.generation_config,
+            max_new_tokens=max_new_tokens
+        )
+
+        # de-tokenizing
+        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
     def compress_documents(self, documents: list[str]) -> torch.Tensor:
-
+        """
+        Compress a list of documents
+        """
         input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
         enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
         attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
-        print('yo', self.decoder.device, enc_input_ids.device, attention_mask.device)
         return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
 
     def generate(self, model_input, max_new_tokens=128):
+        """
+        Generation pipeline including compression + decoding from compressed
+        """
 
         enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
 
-        print('in gen')
-        print(enc_input_ids.size())
-        print(dec_input_ids.size())
-
         assert enc_input_ids.size() == enc_attention_mask.size()
 
         if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
@@ -266,13 +291,11 @@ class PISCO(PreTrainedModel):
                 f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
 
         compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
-        inputs_embeds = self.replace_emb(
+        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
 
-        # Switch adapter if we are training two different ones:
         if 'decoder_adapter' in self.adapter_keys:
-            self.decoder.set_adapter('decoder_adapter')
+            self.decoder.set_adapter('decoder_adapter')
 
-
         output_ids = self.decoder.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=dec_attention_mask,
@@ -280,18 +303,14 @@ class PISCO(PreTrainedModel):
             max_new_tokens=max_new_tokens
         )
 
-
-        return decoded
-
+        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
     def blend_prompt_and_memory_tokens(self, query: str):
         """
        Takes care of blending the prompt with the memory tokens:
         Also returns, if a label is provided, the position of the first token index of the label (for loss comp later on)
         """
-
-        mem_tokens_str = ''.join(self.tokenizer.mem_tokens)
-        mem_tokens_str += self.tokenizer.sep_token
+        mem_tokens_str = ''.join(self.tokenizer.mem_tokens) + self.tokenizer.sep_token
 
         # proper names for "eval" call, don't remove these lines
         docs = mem_tokens_str * self.generation_top_k