import warnings
import os
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig, AutoConfig, GenerationConfig
from jinja2.exceptions import TemplateError


def add_memory_tokens_to_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor, n_mem_tokens: int, tokenizer):
    """
    Concatenate the input ids with n_mem_tokens mem_tokens and update the corresponding attention mask
    """
    assert len(tokenizer.mem_tokens) == n_mem_tokens, f"{len(tokenizer.mem_tokens)} VS {n_mem_tokens}"
    mem_tokens = torch.stack([tokenizer.mem_token_ids_pt] * input_ids.size(0), 0)
    assert len(mem_tokens.size()) == 2
    assert len(mem_tokens) == input_ids.size(0)
    assert len(mem_tokens[0]) == n_mem_tokens
    #mem_tokens = torch.full((input_ids.size(0), n_mem_tokens), tokenizer.mem_token_id, dtype=torch.long)
    input_ids = torch.cat([input_ids, mem_tokens], dim=1)
    attention_mask = torch.cat([attention_mask, torch.ones(input_ids.size(0), n_mem_tokens)], dim=1)
    return input_ids, attention_mask
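    # Example (illustrative, not from the original code), assuming compr_rate=16 and
    # hence 8 memory tokens: input_ids of shape (2, 20) comes back as (2, 28), with the
    # 8 <MEM*> token ids appended to every row and attention_mask extended with ones
    # over those new positions.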


class PISCOConfig(PretrainedConfig):

    model_type = "PISCO"
    def __init__(self,
                decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
                compr_rate: int = 16,
                **kwargs):
        super().__init__(**kwargs)

        self.decoder_model_name = decoder_model_name # model name of decoder
        self.compr_rate = compr_rate # compression rate
        self.lora_r = 16
        self.sep = True
        
        
class PISCO(PreTrainedModel):
    config_class = PISCOConfig
    def __init__(self, cfg):
        super().__init__(cfg)
        self.decoder_model_name = cfg.decoder_model_name
        self.sep = cfg.sep
        self.compr_rate = cfg.compr_rate
        
        self.create_tokenizer(cfg)
        
        # Base model config but we modify vocab size since we added tokens (mainly the mem tokens)
        decoder_config = AutoConfig.from_pretrained(cfg.decoder_model_name)
        decoder_config.vocab_size = len(self.tokenizer)
        
        # Initializing placeholder model:
        self.decoder = AutoModelForCausalLM.from_config(decoder_config, 
                                                        attn_implementation='flash_attention_2', 
                                                        torch_dtype=torch.bfloat16)
        
        peft_config = self.get_peft_config(cfg)

        self.adapter_keys = []
        self.decoder.add_adapter(peft_config, 'decoder_adapter')
        self.decoder.set_adapter('decoder_adapter')
        self.adapter_keys.append('decoder_adapter')
        self.decoder.add_adapter(peft_config, 'encoder_adapter')
        self.adapter_keys.append('encoder_adapter')
        
        self.generation_config = GenerationConfig(do_sample=False, top_p=None)

    def create_tokenizer(self, cfg):
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
        
        n_mem_tokens = 128 // cfg.compr_rate
        mem_tokens = ['<MEM' + str(i) + '>' for i in range(n_mem_tokens)]
        self.tokenizer.add_special_tokens({'additional_special_tokens': mem_tokens + ['<AE>', '<ENC>', '<SEP>']}) 
        self.tokenizer.mem_tokens = mem_tokens
        
        self.tokenizer.mem_token_ids = [self.tokenizer.convert_tokens_to_ids(elt) for elt in self.tokenizer.mem_tokens]
        self.tokenizer.mem_token_ids_pt = torch.LongTensor(self.tokenizer.mem_token_ids) # required later on for operations on tensors
        
        self.tokenizer.ae_token = '<AE>' # token for autoencoding on decoder side
        self.tokenizer.ae_token_id = self.tokenizer.convert_tokens_to_ids('<AE>')
        self.tokenizer.enc_token = '<ENC>' # token for autoencoding on compressor side
        self.tokenizer.sep_token = '<SEP>' # sep token between documents
        self.tokenizer.sep_token_id = self.tokenizer.convert_tokens_to_ids('<SEP>')

        # if no pad token exists, fall back to the bos token
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.bos_token_id
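        # With the default compr_rate=16 this adds 8 memory tokens (<MEM0> ... <MEM7>)
        # plus <AE>, <ENC> and <SEP>, i.e. 11 extra special tokens on top of the base
        # vocabulary, which is why the decoder config's vocab_size is set to
        # len(self.tokenizer) in __init__.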

    def set_all_adapters(self):
        if len(self.adapter_keys) > 0:
            self.decoder.set_adapter(self.adapter_keys)

    def get_peft_config(self, cfg: PISCOConfig) -> LoraConfig:
        """
        Builds the peft config
        """
        return LoraConfig(task_type="CAUSAL_LM", r=cfg.lora_r, lora_alpha=2* cfg.lora_r, target_modules='all-linear', lora_dropout=0.1)
            
    def compress(self, enc_input_ids, enc_attention_mask):
        return self.compr_decoder(enc_input_ids, enc_attention_mask)

    def replace_emb(self, compressed_embs, dec_input_ids):
        """
        Create an input embedding vector combining the compressed_embs and the dec_input_ids
        """
        indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)            
        
        input_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
        num_embs = compressed_embs.size(1)
        if self.sep:
            slot_len = num_embs + 1
        else:
            slot_len = num_embs
        # get first mem_token indices
        first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
        batch_size = input_embeds.size(0)
        # for each example in the batch, replace the memory-token placeholders with the compressed embeddings
        for i in range(batch_size):
            for j in range(indices[i], indices[i + 1]):
                start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
                assert input_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
                    f"{input_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
                input_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]

        return input_embeds

    def compr_decoder(self, input_ids, attention_mask):
        """
        Compression using the decoder
        """
        assert input_ids.size() == attention_mask.size(), f"{input_ids.size()} vs {attention_mask.size()}"
        
        # Switch adapter if we are training two different ones:
        if 'encoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('encoder_adapter')
            
        emb = self.decoder(input_ids=input_ids,
                           attention_mask=attention_mask,
                           output_hidden_states=True).hidden_states[-1]
        mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
        return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
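        # The returned tensor has shape (batch_size, n_mem_tokens, hidden_size): one
        # last-layer hidden state per memory token, which serves as the compressed
        # representation of the input document.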
    
    def prepare_encoder_inputs_to_decoder(self, texts, max_length):
        inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
        inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length+3, truncation=True, add_special_tokens=False)
        num_mem_tokens = 128 // self.compr_rate  # the document size of 128 tokens is hardcoded
        assert num_mem_tokens == len(self.tokenizer.mem_tokens)
        inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'], 
                                                                                        inp_enc['attention_mask'], 
                                                                                        num_mem_tokens, 
                                                                                        tokenizer=self.tokenizer)
        
        return inp_enc
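        # Each document is wrapped as '<ENC>' + bos + text + eos, tokenized with left
        # padding to the longest item in the batch, and the <MEM*> token ids are then
        # appended on the right; compr_decoder later extracts exactly those positions.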
    
    def prepare_encoder_inputs(self, texts, max_length):
        return self.prepare_encoder_inputs_to_decoder(texts, max_length)
        
    def forward(self,
                enc_input_ids: torch.LongTensor = None,
                enc_attention_mask: torch.LongTensor = None,
                dec_input_ids: torch.LongTensor = None,
                dec_attention_mask: torch.LongTensor = None,
                labels: torch.LongTensor = None):
        """
        enc_input_ids: stores the contexts, should be flattened from all queries before input, can be of shape:
            - (batch_size*generation_top_k, enc_token_length)
            - (batch_size, generation_top_k, enc_token_length)
        enc_attention_mask: attention mask of enc_input_ids, same shape as enc_input_ids
        dec_input_ids: stores the prompts (including mem tokens), dimension (batch_size, dec_token_length)
        dec_attention_mask: attention mask of dec_input_ids
        """ 
        assert enc_input_ids.size() == enc_attention_mask.size(), f"{enc_input_ids.size()} vs {enc_attention_mask.size()}"
        
        if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
            batch_size, top_k, seq_length = enc_input_ids.size()
            enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
            enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
        
        # Here, we should have top_k times more elements in enc_input_ids than in dec_input_ids
        assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
            f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
            
        # Perform compression with gradient tracking
        compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)

        # decoding
        if 'decoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('decoder_adapter')

        decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)

        # At end of forward, we need to activate all adapters so that they are both trained...
        self.set_all_adapters()

        return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
    
    def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
        """
        Generates answers from documents (via compression then decoding)
        questions: list of strings
        documents: list of lists of strings (all inner lists should have the same length: the number of documents per question)
        """
        self.generation_top_k = len(documents[0])
        assert len(documents) == len(questions)
        assert all([len(context) == len(documents[0]) for context in documents])
        flat_documents = sum(documents, [])
        
        model_input = {}
        
        # Creating encoder inputs:
        input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
        device = self.decoder.device
        model_input['enc_input_ids'], model_input['enc_attention_mask'] = input_encoder['input_ids'].to(device), input_encoder['attention_mask'].to(device)
        
        # Creating decoder inputs
        instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
        inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True,  max_length=2048)
        model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
        
        # Generation
        return self.generate(model_input, max_new_tokens=max_new_tokens)
    
    def generate_from_compressed_documents_and_questions(self, questions: list[str], compressed_documents: torch.Tensor, max_new_tokens: int = 128) -> list[str]:
        """
        Generates answers from compressed documents
        questions: list of strings
        compressed_documents: torch tensor, its first dimension should be a multiple of len(questions)
        """
        self.generation_top_k = compressed_documents.size(0) // len(questions)
        assert compressed_documents.size(0) % len(questions) == 0, f"{compressed_documents.size(0)} is not a multiple of {len(questions)}"
        
        # Creating decoder inputs
        instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
        inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True,  max_length=2048)
        device = self.decoder.device
        dec_input_ids, dec_attention_mask = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)

        # Creating input decoder embeddings from prompt + compressed documents
        inputs_embeds = self.replace_emb(compressed_documents, dec_input_ids)
        
        # Activating decoder generator:
        if 'decoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('decoder_adapter')
            
        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=dec_attention_mask,
            generation_config=self.generation_config,
            max_new_tokens=max_new_tokens
            )
        
        # de-tokenizing
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        
    def compress_documents(self, documents: list[str]) -> torch.Tensor:
        """
        Compress a list of documents
        """
        input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
        enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
        attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
        return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
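        # Output shape: (len(documents), 128 // compr_rate, hidden_size). This tensor can
        # be cached and later passed to generate_from_compressed_documents_and_questions.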

    def generate(self, model_input, max_new_tokens=128):
        """
        Generation pipeline including compression + decoding from compressed
        """

        enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
        
        assert enc_input_ids.size() == enc_attention_mask.size()
        
        if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
            batch_size, top_k, seq_length = enc_input_ids.size()
            enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
            enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
            
        # Here, we should have top_k times more elements in enc_input_ids than in dec_input_ids
        assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
            f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
            
        compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
        
        if 'decoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('decoder_adapter') 
            
        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=dec_attention_mask,
            generation_config=self.generation_config,
            max_new_tokens=max_new_tokens
            )

        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        
    def blend_prompt_and_memory_tokens(self, query: str):
        """
        Blends the memory-token placeholders for the compressed documents with the prompt for the given query,
        then applies the tokenizer's chat template.
        """        
        mem_tokens_str = ''.join(self.tokenizer.mem_tokens) + self.tokenizer.sep_token
        
        # proper names for "eval" call, don't remove these lines
        docs = mem_tokens_str * self.generation_top_k
        question = query
        
        prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
        prompt_user = f"Background:\n{docs}\n\nQuestion:{question}"
        
        # Prepare the messages with system and user roles
        messages = [
            {"role": "system", "content": prompt_system},
            {"role": "user", "content": prompt_user.replace(':\ ', ': ')}
        ]

        # Attempt to apply the system role and catch if it's not supported
        try:
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            
        except TemplateError as e:
            # Catch the error related to system role and handle it (e.g. gemma)
            if "System role not supported" in str(e):
                # Remove system role and proceed with only the user role
                messages = [{"role": "user", "content": messages[0]['content'] + '\n' + messages[1]['content']}]
                # Apply template again without system role
                prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            else:
                # Re-raise the exception if it's unrelated to system role
                raise e

        return prompt
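

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original file). It builds an
    # *untrained* PISCO wrapper from the default config; in practice you would load a
    # trained checkpoint instead (e.g. via AutoModel.from_pretrained(<pisco-checkpoint>,
    # trust_remote_code=True)). Running this requires a GPU plus the flash-attn package,
    # since the decoder is instantiated with attn_implementation='flash_attention_2',
    # and access to the meta-llama/Llama-2-7b-chat-hf tokenizer and config on the
    # Hugging Face Hub (a gated repository).
    cfg = PISCOConfig(compr_rate=16)
    model = PISCO(cfg).to('cuda').eval()

    questions = ["Who wrote Les Miserables?"]
    documents = [[
        "Les Miserables is a French historical novel by Victor Hugo, first published in 1862.",
        "Victor Hugo was a French poet and novelist of the Romantic movement.",
    ]]

    # Compress the documents and generate an answer in one call:
    with torch.no_grad():
        answers = model.generate_from_text(questions, documents, max_new_tokens=64)
    print(answers)

    # Alternatively, compress once and reuse the compressed representations:
    with torch.no_grad():
        embeddings = model.compress_documents(sum(documents, []))
        answers = model.generate_from_compressed_documents_and_questions(questions, embeddings)
    print(answers)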