"""Decoder-only LLM sentence encoder with instruction-aware pooling and layer pruning."""

import json
import logging
import os
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from peft import PeftModel
from torch import Tensor, device, nn
from tqdm.autonotebook import tqdm, trange
from transformers import (
    AutoModel,
    AutoConfig,
    PretrainedConfig,
    AutoTokenizer,
    LlamaConfig,
    MistralConfig,
    GemmaConfig,
    Qwen2Config,
)

logger = logging.getLogger(__name__)

# Marker written between instruction and text by ``LLMEncoder._convert_to_str``
# and split on by ``LLMEncoder.tokenize``. Everything after it is the part whose
# tokens are marked in ``embed_mask``.
_INSTRUCTION_SEPARATOR = "!@#$%^&*()"


def batch_to_device(batch, target_device: device):
    """Move every tensor value of ``batch`` to ``target_device``.

    The mapping is updated in place; non-tensor values are left untouched.
    The same mapping is returned for convenience.
    """
    for key, value in batch.items():
        if isinstance(value, Tensor):
            batch[key] = value.to(target_device)
    return batch


class LLMEncoder(nn.Module):
    """Sentence encoder that pools the last hidden states of a transformer.

    Inputs may carry an instruction prefix. When ``skip_instruction`` is True,
    only the tokens after the instruction contribute to the pooled embedding.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        pooling_mode: str = "weighted_mean",
        max_length: int = 512,
        doc_max_length: int = 400,
        skip_instruction: bool = True,
    ):
        """
        Args:
            model: Transformer whose output exposes ``last_hidden_state``.
            tokenizer: Tokenizer for ``model``. ``get_pooling`` requires left padding.
            pooling_mode: One of "mean", "weighted_mean", "eos_token",
                "last_token" or "bos_token".
            max_length: ``max_length`` passed to every tokenizer call.
            doc_max_length: Token budget the text part is trimmed towards in
                ``_convert_to_str``.
            skip_instruction: Pool only over the tokens after the instruction.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.pooling_mode = pooling_mode
        self.skip_instruction = skip_instruction
        self.max_length = max_length
        self.doc_max_length = doc_max_length
        self.config = model.config

    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        cache_dir=None,
        **kwargs,
    ):
        """
        Load a pretrained model from a model identifier or path.

        Args:
            base_model_name_or_path: Model identifier or path to pretrained model.
            peft_model_name_or_path: Path to any PEFT models to apply. The
                adapter is merged into the base weights.
            cache_dir: Cache directory forwarded to the tokenizer and model loaders.
            **kwargs: Encoder options ("pooling_mode", "max_length",
                "doc_max_length", "skip_instruction") are consumed here.
                Everything else is forwarded to ``AutoModel.from_pretrained``.

        Returns:
            L3Prune model.
        """
        # Always pop the encoder options so they never reach AutoModel, even
        # when explicitly passed as None. Only non-None values override.
        encoder_args = {}
        for key in ("pooling_mode", "max_length", "doc_max_length", "skip_instruction"):
            value = kwargs.pop(key, None)
            if value is not None:
                encoder_args[key] = value

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, cache_dir=cache_dir)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        model = AutoModel.from_pretrained(base_model_name_or_path, cache_dir=cache_dir, **kwargs)

        # For a local checkpoint, copy the "_name_or_path" stored in its
        # config.json onto the model config. prepare_for_tokenization selects
        # prompt templates by this name. Left untouched if the key is absent.
        config_path = os.path.join(base_model_name_or_path, "config.json")
        if os.path.isdir(base_model_name_or_path) and os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f_in:
                config_dict = json.load(f_in)
            stored_name = config_dict.get("_name_or_path")
            if stored_name:
                model.config._name_or_path = stored_name

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            model = model.merge_and_unload()

        # Encoder settings: l3prune_config.json (if present) overridden by kwargs.
        encoder_config = {}
        l3prune_config_path = os.path.join(base_model_name_or_path, "l3prune_config.json")
        if os.path.exists(l3prune_config_path):
            with open(l3prune_config_path, "r", encoding="utf-8") as f_in:
                encoder_config.update(json.load(f_in))
        encoder_config.update(encoder_args)

        return cls(model=model, tokenizer=tokenizer, **encoder_config)

    def prune(self, percent_prune=0):
        """
        Prune a model to a percentage of layers of the base model.

        If ``percent_prune`` is equal to or greater than 1, it is taken as the
        specific layer number to prune to. For example, if percent_prune=0.3,
        30% of the layers will be pruned. If percent_prune=3, the model will be
        pruned to 3 layers. Only the trailing layers are dropped. The target
        never exceeds the number of layers currently present.

        Raises:
            ValueError: If ``percent_prune`` is negative.
        """
        if percent_prune < 0:
            raise ValueError(f"percent_prune must be non-negative, got {percent_prune}.")

        if percent_prune >= 1:
            # Interpreted as an absolute layer count.
            new_num_layers = int(percent_prune)
        else:
            new_num_layers = int(self.model.config.num_hidden_layers * (1 - percent_prune))
        # Keep config.num_hidden_layers consistent with the actual layer list.
        new_num_layers = min(new_num_layers, len(self.model.layers))

        logger.info("Pruning to %d layer.", new_num_layers)
        self.model.layers = self.model.layers[:new_num_layers]
        self.model.config.num_hidden_layers = new_num_layers

    def prepare_for_tokenization(self, text):
        """Wrap ``text`` in the prompt format expected by the loaded checkpoint.

        Known instruct checkpoints (matched by ``config._name_or_path``) get
        their user-turn template. With ``eos_token`` pooling, an end-of-sequence
        marker is appended so the last position is that marker.
        """
        name = self.model.config._name_or_path
        if name == "meta-llama/Meta-Llama-3-8B-Instruct":
            text = (
                "<|start_header_id|>user<|end_header_id|>\n\n"
                + text.strip()
                + "<|eot_id|>"
            )
            return text
        if name in [
            "mistralai/Mistral-7B-Instruct-v0.2",
            "meta-llama/Llama-2-7b-chat-hf",
        ]:
            text = "[INST] " + text.strip() + " [/INST]"
        if name in [
            "google/gemma-2-9b-it",
        ]:
            text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
        if name in [
            "Qwen/Qwen2-1.5B-Instruct",
            "Qwen/Qwen2-7B-Instruct",
        ]:
            text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
        if self.pooling_mode == "eos_token":
            if name == "meta-llama/Meta-Llama-3-8B":
                text = text.strip() + "<|end_of_text|>"
            elif isinstance(self.model.config, (LlamaConfig, MistralConfig)):
                text = text.strip() + " </s>"
            elif isinstance(self.model.config, GemmaConfig):
                # NOTE(review): Gemma-2 checkpoints may use a different config
                # class and so may not match here. Confirm for the models in use.
                text = text.strip() + "<eos>"
            elif isinstance(self.model.config, Qwen2Config):
                text = text.strip() + "<|endoftext|>"
        return text

    def tokenize(self, texts):
        """Tokenize ``texts`` and add an ``embed_mask`` for the embedded part.

        Each text may contain ``_INSTRUCTION_SEPARATOR``. The separator is
        removed before tokenization. The ``embed_mask`` entry has the same
        shape as ``attention_mask`` and holds 1 on the trailing positions that
        correspond to the text after the separator, and 0 elsewhere.

        Returns:
            The tokenizer's batch encoding (pt tensors), plus ``embed_mask``.
        """
        embed_texts = []
        original_texts = []
        for text in texts:
            parts = text.split(_INSTRUCTION_SEPARATOR)
            embed_texts.append(parts[1] if len(parts) > 1 else "")
            original_texts.append("".join(parts))

        original = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        # Token counts of the embedded parts, from one batched tokenizer call.
        embed_token_ids = self.tokenizer(
            embed_texts,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
        )["input_ids"]

        embed_mask = torch.zeros_like(original["attention_mask"])
        padded_length = embed_mask.shape[-1]
        for row, ids in enumerate(embed_token_ids):
            # Clamp so a suffix that tokenizes longer than the padded row
            # cannot cause a shape mismatch.
            count = min(len(ids), padded_length)
            if count > 0:
                # Inputs are left-padded, so the embedded text sits at the end.
                embed_mask[row, -count:] = 1
        original["embed_mask"] = embed_mask
        return original

    def _skip_instruction(self, sentence_feature):
        """Replace ``attention_mask`` by ``embed_mask`` so pooling ignores the instruction.

        Raises:
            ValueError: If ``embed_mask`` is missing or its shape differs from
                ``attention_mask``.
        """
        embed_mask = sentence_feature.get("embed_mask")
        if embed_mask is None:
            raise ValueError(
                "skip_instruction=True requires an 'embed_mask' in the features "
                "(as produced by tokenize())."
            )
        if sentence_feature["attention_mask"].shape != embed_mask.shape:
            raise ValueError(
                "attention_mask and embed_mask must have the same shape, got "
                f"{tuple(sentence_feature['attention_mask'].shape)} and "
                f"{tuple(embed_mask.shape)}."
            )
        sentence_feature["attention_mask"] = embed_mask

    def forward(self, sentence_feature: Dict[str, Tensor]):
        """Run the model and pool its last hidden states into one vector per row.

        ``embed_mask`` is hidden from the model call and put back afterwards,
        even if the model raises.
        """
        embed_mask = sentence_feature.pop("embed_mask", None)
        try:
            reps = self.model(**sentence_feature)
        finally:
            sentence_feature["embed_mask"] = embed_mask
        return self.get_pooling(sentence_feature, reps.last_hidden_state)

    def get_pooling(self, features, last_hidden_states):
        """Pool ``last_hidden_states`` according to ``self.pooling_mode``.

        With ``skip_instruction``, ``features["attention_mask"]`` is replaced
        by ``embed_mask`` first, and that replacement mutates ``features``. For
        "mean" and "weighted_mean", the pooled span is the last
        ``attention_mask.sum()`` positions of each row. A row with an empty
        span pools to a zero vector.

        Raises:
            ValueError: If the tokenizer does not pad on the left, or the
                pooling mode is unknown.
        """
        # All models padded from left
        if self.tokenizer.padding_side != "left":
            raise ValueError("Pooling modes are implemented for padding from left.")
        if self.skip_instruction:
            self._skip_instruction(features)
        seq_lengths = features["attention_mask"].sum(dim=-1).tolist()

        if self.pooling_mode == "mean":
            pooled = []
            for i, length in enumerate(seq_lengths):
                if length > 0:
                    pooled.append(last_hidden_states[i, -length:, :].mean(dim=0))
                else:
                    # A [-0:] slice would average the whole padded row.
                    pooled.append(last_hidden_states.new_zeros(last_hidden_states.shape[-1]))
            return torch.stack(pooled, dim=0)
        elif self.pooling_mode == "weighted_mean":
            bs, seq_len, _ = last_hidden_states.shape
            complete_weights = torch.zeros(bs, seq_len, device=last_hidden_states.device)
            for i, length in enumerate(seq_lengths):
                if length > 0:
                    # Linearly increasing weights 1..length over the span,
                    # normalized to sum to 1.
                    complete_weights[i, -length:] = torch.arange(
                        1, length + 1, device=complete_weights.device
                    )
                    complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9)
            return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
        elif self.pooling_mode in ("eos_token", "last_token"):
            return last_hidden_states[:, -1]
        elif self.pooling_mode == "bos_token":
            return last_hidden_states[features["input_ids"] == self.tokenizer.bos_token_id]
        else:
            raise ValueError(f"{self.pooling_mode} is not implemented yet.")

    def _convert_to_str(self, instruction, text):
        """Join ``instruction`` and ``text`` with the separator, trimming ``text``.

        While ``text`` has more than ``doc_max_length`` tokens (each count is
        capped at ``max_length``), it is cut to a proportional prefix of its
        whitespace-separated words.
        """

        def count_tokens(value):
            encoded = self.tokenizer(
                value,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False,
            )
            return len(encoded["input_ids"])

        token_count = count_tokens(text)
        while token_count > self.doc_max_length:
            words = text.split()
            reduction_ratio = self.doc_max_length / token_count
            # int() truncates, so each pass keeps fewer words than before.
            keep = int(len(words) * reduction_ratio)
            text = " ".join(words[:keep])
            token_count = count_tokens(text)

        if instruction:
            return f"{instruction.strip()} {_INSTRUCTION_SEPARATOR}{text}"
        return f"{_INSTRUCTION_SEPARATOR}{text}"

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: bool = True,
        convert_to_numpy: bool = False,
        convert_to_tensor: bool = False,
        device: Optional[str] = None,
    ):
        """
        Encode a list of sentences to their respective embeddings.

        The sentences can be a list of strings, a single string (treated as a
        one-element list), or a list of ``[instruction, text]`` pairs.

        Args:
            sentences: sentence or sentences to encode.
            batch_size: batch size for turning sentence tokens into embeddings.
            show_progress_bar: whether to show progress bars during encoding steps.
            convert_to_numpy: If true, return numpy arrays instead of torch tensors.
            convert_to_tensor: If true, return torch tensors (default). Takes
                precedence over ``convert_to_numpy``.
            device: torch backend device identifier (e.g., 'cuda', 'cpu', 'mps' etc.).
                If not specified, the default is to use cuda when available, otherwise cpu.
                Only CUDA devices use the multi-GPU process pool. An explicit
                non-CUDA device always runs in-process.

        Returns:
            A float32 array/tensor with one row per input, in input order.
            Embeddings are detached and on the CPU. An empty input yields
            zero rows.

        Raises:
            TypeError: If an instruction or text is not a string.
        """
        if convert_to_tensor:
            convert_to_numpy = False

        # A bare string is one sentence, not a sequence of characters.
        if isinstance(sentences, str):
            sentences = [sentences]

        if len(sentences) == 0:
            empty = torch.empty(
                (0, getattr(self.config, "hidden_size", 0)), dtype=torch.float32
            )
            return empty.numpy() if convert_to_numpy else empty

        if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
            sentences = [sentences]  # required for MEDI version of MTEB
        if isinstance(sentences[0], str):
            sentences = [["", sentence] for sentence in sentences]

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        concatenated_input_texts = []
        for sentence in sentences:
            instruction, text = sentence[0], sentence[1]
            if not isinstance(instruction, str) or not isinstance(text, str):
                raise TypeError(
                    "Each input must be a string or an [instruction, text] pair of strings."
                )
            concatenated_input_texts.append(self._convert_to_str(instruction, text))
        sentences = concatenated_input_texts

        self.eval()

        # Longest-first ordering groups similar lengths into the same batch.
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        use_multiprocessing = torch.cuda.device_count() > 1 and str(device).startswith("cuda")

        if not use_multiprocessing:
            # This branch also support mps devices
            self.to(device)
            all_embeddings = []
            for start_index in trange(
                0,
                len(sentences),
                batch_size,
                desc="Batches",
                disable=not show_progress_bar,
            ):
                sentences_batch = sentences_sorted[start_index : start_index + batch_size]
                all_embeddings.append(
                    self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy)
                )
        else:
            sentences_batches = [
                sentences_sorted[start_index : start_index + batch_size]
                for start_index in range(0, len(sentences), batch_size)
            ]
            progress_bar = tqdm(
                total=len(sentences_batches),
                desc="Batches",
                disable=not show_progress_bar,
            )

            def update(*args):
                progress_bar.update()

            cuda_compatible_multiprocess = mp.get_context("spawn")
            with cuda_compatible_multiprocess.Pool(torch.cuda.device_count()) as pool:
                results = [
                    pool.apply_async(
                        self._encode,
                        args=(batch, None, convert_to_numpy, True),
                        callback=update,
                    )
                    for batch in sentences_batches
                ]
                # Collect inside the with-block: Pool.__exit__ terminates workers.
                all_embeddings = [result.get() for result in results]
            progress_bar.close()

        all_embeddings = torch.cat(all_embeddings, dim=0)
        # Undo the length sort so rows line up with the caller's inputs.
        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
        all_embeddings = all_embeddings.to(torch.float32)
        if convert_to_numpy:
            all_embeddings = all_embeddings.numpy()
        return all_embeddings

    def save(self, output_path, merge_before_save=False, save_config=True):
        """Save model, tokenizer and (optionally) the encoder settings to ``output_path``.

        The encoder settings go to ``l3prune_config.json``, which
        ``from_pretrained`` reads back.
        """
        if merge_before_save and isinstance(self.model, PeftModel):
            self.model = self.model.merge_and_unload()
            # NOTE(review): presumably clears the adapter flag so save_pretrained
            # writes the merged weights rather than only an adapter. Confirm.
            if hasattr(self.model, "_hf_peft_config_loaded"):
                self.model._hf_peft_config_loaded = False

        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        l3prune_config = {
            "pooling_mode": self.pooling_mode,
            "max_length": self.max_length,
            "doc_max_length": self.doc_max_length,
            "skip_instruction": self.skip_instruction,
        }

        if save_config:
            os.makedirs(output_path, exist_ok=True)
            with open(
                os.path.join(output_path, "l3prune_config.json"), "w", encoding="utf-8"
            ) as f_out:
                json.dump(l3prune_config, f_out, indent=4)

    def _encode(
        self,
        sentences_batch,
        device: Optional[str] = None,
        convert_to_numpy: bool = False,
        multiprocessing=False,
    ):
        """Embed one batch of already-converted sentences; return a detached CPU tensor.

        ``convert_to_numpy`` is accepted for call compatibility but not used
        here. ``encode`` performs the conversion after concatenation.
        """
        if multiprocessing:
            # Pool worker: when no device is given, derive a CUDA device from the
            # worker's process identity so workers spread across visible GPUs.
            rank = mp.current_process()._identity[0]
            if device is None and torch.cuda.is_available():
                device = f"cuda:{rank % torch.cuda.device_count()}"

        self.to(device)
        features = self.tokenize(
            [self.prepare_for_tokenization(sentence) for sentence in sentences_batch]
        )
        features = batch_to_device(features, device)

        with torch.no_grad():
            embeddings = self.forward(features)
            embeddings = embeddings.detach()
            embeddings = embeddings.cpu()

        return embeddings

    def _text_length(self, text: Union[str, List[int], List[List[int]]]):
        """
        Help function to get the length for the input text. Text can be either
        a string (which means a single text), a list of ints (which means a
        single tokenized text), or a tuple of list of ints (representing
        several text inputs to the model). A dict yields the length of its
        first value, an object without ``len()`` yields 1, and an empty input
        yields 0.
        """
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values()))) if text else 0
        if not hasattr(text, "__len__"):  # Object has no len() method
            return 1
        # Check emptiness before indexing text[0].
        if isinstance(text, str) or len(text) == 0:
            return len(text)
        if isinstance(text, (list, tuple)) and isinstance(text[0], int):
            return len(text)  # single tokenized text
        return sum(len(t) for t in text)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """Delegate to the wrapped model's ``resize_token_embeddings``."""
        return self.model.resize_token_embeddings(
            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Delegate to the wrapped model's ``gradient_checkpointing_enable``."""
        self.model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )