# pisco-mistral / modelling_pisco.py
import warnings
import os
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig, AutoConfig, GenerationConfig
from jinja2.exceptions import TemplateError
def add_memory_tokens_to_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor, n_mem_tokens: int, tokenizer):
"""
Concatenate the input ids with n_mem_tokens mem_tokens and update the corresponding attention mask
"""
assert len(tokenizer.mem_tokens) == n_mem_tokens, f"{len(tokenizer.mem_tokens)} VS {n_mem_tokens}"
mem_tokens = torch.stack([tokenizer.mem_token_ids_pt] * input_ids.size(0), 0)
assert len(mem_tokens.size()) == 2
assert len(mem_tokens) == input_ids.size(0)
assert len(mem_tokens[0]) == n_mem_tokens
#mem_tokens = torch.full((input_ids.size(0), n_mem_tokens), tokenizer.mem_token_id, dtype=torch.long)
input_ids = torch.cat([input_ids, mem_tokens], dim=1)
    # Extend the mask with ones, keeping the original mask's dtype and device
    attention_mask = torch.cat([attention_mask, torch.ones(input_ids.size(0), n_mem_tokens, dtype=attention_mask.dtype, device=attention_mask.device)], dim=1)
return input_ids, attention_mask
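# Example: with compr_rate=16 the tokenizer holds 128 // 16 = 8 memory tokens, so a left-padded
# batch of shape (batch_size, seq_len) becomes (batch_size, seq_len + 8), with the
# <MEM0>...<MEM7> ids appended on the right and the attention mask extended with ones.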
class PISCOConfig(PretrainedConfig):
model_type = "PISCO"
def __init__(self,
decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
compr_rate: int = 16,
**kwargs):
super().__init__(**kwargs)
self.decoder_model_name = decoder_model_name # model name of decoder
self.compr_rate = compr_rate # compression rate
self.lora_r = 16
self.sep = True
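# Example (hypothetical values, shown for illustration only):
#   cfg = PISCOConfig(decoder_model_name="mistralai/Mistral-7B-Instruct-v0.2", compr_rate=16)
#   gives 128 // 16 = 8 memory tokens per document, LoRA rank 16, and <SEP> separators enabled.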
class PISCO(PreTrainedModel):
config_class = PISCOConfig
def __init__(self, cfg):
super().__init__(cfg)
self.decoder_model_name = cfg.decoder_model_name
self.sep = cfg.sep
self.compr_rate = cfg.compr_rate
self.create_tokenizer(cfg)
# Base model config but we modify vocab size since we added tokens (mainly the mem tokens)
decoder_config = AutoConfig.from_pretrained(cfg.decoder_model_name)
decoder_config.vocab_size = len(self.tokenizer)
# Initializing placeholder model:
self.decoder = AutoModelForCausalLM.from_config(decoder_config,
attn_implementation='flash_attention_2',
torch_dtype=torch.bfloat16)
peft_config = self.get_peft_config(cfg)
self.adapter_keys = []
self.decoder.add_adapter(peft_config, 'decoder_adapter')
self.decoder.set_adapter('decoder_adapter')
self.adapter_keys.append('decoder_adapter')
self.decoder.add_adapter(peft_config, 'encoder_adapter')
self.adapter_keys.append('encoder_adapter')
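        # Two LoRA adapters share the same base weights: 'encoder_adapter' is activated when
        # compressing documents (compr_decoder) and 'decoder_adapter' when generating answers
        # from the compressed embeddings.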
self.generation_config = GenerationConfig(do_sample=False, top_p=None)
def create_tokenizer(self, cfg):
self.tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
n_mem_tokens = 128 // cfg.compr_rate
mem_tokens = ['<MEM' + str(i) + '>' for i in range(n_mem_tokens)]
self.tokenizer.add_special_tokens({'additional_special_tokens': mem_tokens + ['<AE>', '<ENC>', '<SEP>']})
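        # For compr_rate=16 this registers 11 extra special tokens: <MEM0>...<MEM7>, <AE>, <ENC> and <SEP>.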
self.tokenizer.mem_tokens = mem_tokens
self.tokenizer.mem_token_ids = [self.tokenizer.convert_tokens_to_ids(elt) for elt in self.tokenizer.mem_tokens]
self.tokenizer.mem_token_ids_pt = torch.LongTensor(self.tokenizer.mem_token_ids) # required later on for operations on tensors
self.tokenizer.ae_token = '<AE>' # token for autoencoding on decoder side
self.tokenizer.ae_token_id = self.tokenizer.convert_tokens_to_ids('<AE>')
self.tokenizer.enc_token = '<ENC>' # token for autoencoding on compressor side
        self.tokenizer.sep_token = '<SEP>' # sep token between documents
self.tokenizer.sep_token_id = self.tokenizer.convert_tokens_to_ids('<SEP>')
        # if a pad token exists then use it, otherwise fall back to the bos token
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.bos_token_id
def set_all_adapters(self):
if len(self.adapter_keys) > 0:
self.decoder.set_adapter(self.adapter_keys)
def get_peft_config(self, cfg: PISCOConfig) -> LoraConfig:
"""
Builds the peft config
"""
        return LoraConfig(task_type="CAUSAL_LM", r=cfg.lora_r, lora_alpha=2 * cfg.lora_r, target_modules='all-linear', lora_dropout=0.1)
def compress(self, enc_input_ids, enc_attention_mask):
return self.compr_decoder(enc_input_ids, enc_attention_mask)
def replace_emb(self, compressed_embs, dec_input_ids):
"""
Create an input embedding vector combining the compressed_embs and the dec_input_ids
"""
indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
input_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
num_embs = compressed_embs.size(1)
if self.sep:
slot_len = num_embs + 1
else:
slot_len = num_embs
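        # Each document's compressed embeddings occupy one contiguous "slot" of mem-token
        # positions in the prompt; when self.sep is set, the <SEP> after each block adds one position.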
# get first mem_token indices
first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
batch_size = input_embeds.size(0)
# for each example in batch, replace them with compressed embeddings
for i in range(batch_size):
for j in range(indices[i], indices[i + 1]):
start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
assert input_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
f"{input_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
input_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
return input_embeds
def compr_decoder(self, input_ids, attention_mask):
"""
Compression using the decoder
"""
assert input_ids.size() == attention_mask.size(), f"{input_ids.size()} vs {attention_mask.size()}"
# Switch adapter if we are training two different ones:
if 'encoder_adapter' in self.adapter_keys:
self.decoder.set_adapter('encoder_adapter')
emb = self.decoder(input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True).hidden_states[-1]
mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
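        # Keep only the hidden states at the mem-token positions:
        # shape (n_documents, n_mem_tokens, hidden_size), one row per compressed document.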
return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
def prepare_encoder_inputs_to_decoder(self, texts, max_length):
inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length+3, truncation=True, add_special_tokens=False)
        num_mem_tokens = 128 // self.compr_rate # hard-coded 128-token document size, must match create_tokenizer
assert num_mem_tokens == len(self.tokenizer.mem_tokens)
inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
inp_enc['attention_mask'],
num_mem_tokens,
tokenizer=self.tokenizer)
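        # Resulting encoder sequence per document:
        #   <ENC><bos> document tokens (truncated to max_length) <eos> <MEM0>...<MEMk>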
return inp_enc
def prepare_encoder_inputs(self, texts, max_length):
return self.prepare_encoder_inputs_to_decoder(texts, max_length)
def forward(self,
enc_input_ids: torch.LongTensor = None,
enc_attention_mask: torch.LongTensor = None,
dec_input_ids: torch.LongTensor = None,
dec_attention_mask: torch.LongTensor = None,
labels: torch.LongTensor = None):
"""
enc_input_ids: stores the contexts, should be flattened from all queries before input, can be of shape:
- (batch_size*generation_top_k, enc_token_length)
- (batch_size, generation_top_k, enc_token_length)
enc_attention_mask: attention mask of enc_input_ids, same shape as enc_input_ids
        dec_input_ids: stores the prompts (including mem tokens), dimension (batch_size, dec_token_length)
dec_attention_mask: attention mask of dec_input_ids
"""
assert enc_input_ids.size() == enc_attention_mask.size(), f"{enc_input_ids.size()} vs {enc_attention_mask.size()}"
if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
batch_size, top_k, seq_length = enc_input_ids.size()
enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
# Here, we should have top_k times more elements in enc_input_ids than in dec_input_ids
assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
# Perform compression with gradient tracking
compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
# decoding
if 'decoder_adapter' in self.adapter_keys:
self.decoder.set_adapter('decoder_adapter')
decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
# At end of forward, we need to activate all adapters so that they are both trained...
self.set_all_adapters()
return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
"""
Generates answers from documents (via compression then decoding)
questions: list of string
        documents: list of list of strings (all inner lists must have the same length: the number of documents per question)
"""
self.generation_top_k = len(documents[0])
assert len(documents) == len(questions)
assert all([len(context) == len(documents[0]) for context in documents])
flat_documents = sum(documents, [])
model_input = {}
# Creating encoder inputs:
input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
device = self.decoder.device
model_input['enc_input_ids'], model_input['enc_attention_mask'] = input_encoder['input_ids'].to(device), input_encoder['attention_mask'].to(device)
# Creating decoder inputs
instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
# Generation
return self.generate(model_input, max_new_tokens=max_new_tokens)
def generate_from_compressed_documents_and_questions(self, questions: list[str], compressed_documents: torch.Tensor, max_new_tokens: int = 128) -> list[str]:
"""
Generates answers from compressed documents
questions: list of string
compressed_documents: torch tensor, its first dimension should be a multiple of len(questions)
"""
self.generation_top_k = compressed_documents.size(0) // len(questions)
assert compressed_documents.size(0) % self.generation_top_k == 0, f"{compressed_documents.size(0)} {self.generation_top_k}"
# Creating decoder inputs
instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
device = self.decoder.device
dec_input_ids, dec_attention_mask = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
# Creating input decoder embeddings from prompt + compressed documents
inputs_embeds = self.replace_emb(compressed_documents, dec_input_ids)
# Activating decoder generator:
if 'decoder_adapter' in self.adapter_keys:
self.decoder.set_adapter('decoder_adapter')
output_ids = self.decoder.generate(
inputs_embeds=inputs_embeds,
attention_mask=dec_attention_mask,
generation_config=self.generation_config,
max_new_tokens=max_new_tokens
)
# de-tokenizing
return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
def compress_documents(self, documents: list[str]) -> torch.Tensor:
"""
Compress a list of documents
"""
input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
def generate(self, model_input, max_new_tokens=128):
"""
Generation pipeline including compression + decoding from compressed
"""
enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
assert enc_input_ids.size() == enc_attention_mask.size()
if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
batch_size, top_k, seq_length = enc_input_ids.size()
enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
# Here, we should have top_k times more elements in enc_input_ids than in dec_input_ids
assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
if 'decoder_adapter' in self.adapter_keys:
self.decoder.set_adapter('decoder_adapter')
output_ids = self.decoder.generate(
inputs_embeds=inputs_embeds,
attention_mask=dec_attention_mask,
generation_config=self.generation_config,
max_new_tokens=max_new_tokens
)
return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
def blend_prompt_and_memory_tokens(self, query: str):
"""
        Blends the chat prompt with the memory-token placeholders: one block of mem tokens
        (followed by <SEP>) per retrieved document is inserted before the question.
"""
mem_tokens_str = ''.join(self.tokenizer.mem_tokens) + self.tokenizer.sep_token
        # readable names used in the user prompt below
docs = mem_tokens_str * self.generation_top_k
question = query
prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
prompt_user = f"Background:\n{docs}\n\nQuestion:{question}"
# Prepare the messages with system and user roles
messages = [
{"role": "system", "content": prompt_system},
{"role": "user", "content": prompt_user.replace(':\ ', ': ')}
]
# Attempt to apply the system role and catch if it's not supported
try:
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except TemplateError as e:
# Catch the error related to system role and handle it (e.g. gemma)
if "System role not supported" in str(e):
# Remove system role and proceed with only the user role
messages = [{"role": "user", "content": messages[0]['content'] + '\n' + messages[1]['content']}]
# Apply template again without system role
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
# Re-raise the exception if it's unrelated to system role
raise e
return prompt
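

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: the checkpoint id below and GPU availability).
    # Replace the id with the hub repository that actually hosts this modelling file; a CUDA
    # device is required since the decoder is built with flash_attention_2 and bfloat16.
    checkpoint = "naver/pisco-mistral"  # assumed checkpoint id
    model = PISCO.from_pretrained(checkpoint).to("cuda")
    questions = ["Who wrote 'The Old Man and the Sea'?"]
    documents = [[
        "Ernest Hemingway wrote the novella 'The Old Man and the Sea' in 1951.",
        "The novella was awarded the Pulitzer Prize for Fiction in 1953.",
    ]]
    print(model.generate_from_text(questions, documents, max_new_tokens=64))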