# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Train a reward model (a single-output sequence classifier) on
(question, response_chosen, response_rejected) pairs with the InstructGPT
pairwise log-sigmoid loss, optionally with LoRA adapters via PEFT.
"""
import math
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Any, List, Union, Optional, Dict

import torch
from datasets import load_dataset
from loguru import logger
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
from sklearn.metrics import mean_squared_error, mean_absolute_error
from torch.utils.data import Dataset
from transformers import (
    AutoConfig,
    PreTrainedTokenizerBase,
    BloomForSequenceClassification,
    LlamaForSequenceClassification,
    LlamaTokenizer,
    BloomTokenizerFast,
    AlbertForSequenceClassification,
    BertForSequenceClassification,
    BertTokenizer,
    AutoTokenizer,
    RobertaForSequenceClassification,
    AutoModelForSequenceClassification,
    RobertaTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer import TRAINING_ARGS_NAME

# model_type -> (config class, sequence-classification model class, tokenizer class)
MODEL_CLASSES = {
    "bert": (AutoConfig, BertForSequenceClassification, BertTokenizer),
    "roberta": (AutoConfig, RobertaForSequenceClassification, RobertaTokenizer),
    "albert": (AutoConfig, AlbertForSequenceClassification, AutoTokenizer),
    "bloom": (AutoConfig, BloomForSequenceClassification, BloomTokenizerFast),
    "llama": (AutoConfig, LlamaForSequenceClassification, LlamaTokenizer),
    "auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
}


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Both `model_type` and `model_name_or_path` are mandatory: `__post_init__`
    raises ValueError when either is missing.
    """

    model_type: str = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )

    def __post_init__(self):
        """Validate that the mandatory model arguments were supplied."""
        if self.model_type is None:
            raise ValueError(
                "You must specify a valid model_type to run training. Available model types are " + ", ".join(
                    MODEL_CLASSES.keys()))
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Either `dataset_name` (a hub dataset) or `train_file_dir` / `validation_file_dir`
    (folders searched recursively for *.json / *.jsonl files) is used as the data source.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
    # max_source_length + max_target_length is the maximum tokenized pair length kept after filtering.
    max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=4,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )


@dataclass
class PeftArguments(TrainingArguments):
    """TrainingArguments extended with LoRA / PEFT options."""

    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    # Comma-separated module names; "all" selects every linear layer except lm_head/score.
    target_modules: Optional[str] = field(default="all")
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.05)
    lora_alpha: Optional[float] = field(default=32.0)
    # Comma-separated names of extra modules to train and save in full.
    modules_to_save: Optional[str] = field(default=None)
    # Existing adapter to resume from instead of initialising a new one.
    peft_path: Optional[str] = field(default=None)


def compute_metrics(eval_preds):
    """
    Compute MSE and MAE between the chosen and rejected rewards.

    `RewardTrainer.prediction_step` returns rewards_chosen as the "predictions"
    and rewards_rejected as the "labels", so both metrics measure the distance
    between the two reward scores rather than against a ground-truth label.
    """
    preds, labels = eval_preds
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    mse = mean_squared_error(labels, preds)
    mae = mean_absolute_error(labels, preds)

    return {"mse": mse, "mae": mae}


@dataclass
class RewardDataCollatorWithPadding:
    """
    Batch chosen/rejected pairs.

    Each feature holds `input_ids_chosen`, `attention_mask_chosen`,
    `input_ids_rejected` and `attention_mask_rejected`. The chosen and rejected
    sides are padded independently with `tokenizer.pad`, then re-assembled
    under the same four keys. `return_loss` is always True.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features_chosen = []
        features_rejected = []
        for feature in features:
            features_chosen.append(
                {
                    "input_ids": feature["input_ids_chosen"],
                    "attention_mask": feature["attention_mask_chosen"],
                }
            )
            features_rejected.append(
                {
                    "input_ids": feature["input_ids_rejected"],
                    "attention_mask": feature["attention_mask_rejected"],
                }
            )
        batch_chosen = self.tokenizer.pad(
            features_chosen,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch_rejected = self.tokenizer.pad(
            features_rejected,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
            "return_loss": True,
        }
        return batch


class RewardTrainer(Trainer):
    """
    Trainer for reward models.

    The loss is the InstructGPT pairwise log-loss
    -log(sigmoid(r_chosen - r_rejected)), averaged over the batch:
    https://arxiv.org/abs/2203.02155
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Score both sides of each pair and return the pairwise loss.

        Extra keyword arguments that newer Trainer versions pass
        (e.g. `num_items_in_batch`) are accepted and ignored.
        """
        rewards_chosen = model(input_ids=inputs["input_ids_chosen"],
                               attention_mask=inputs["attention_mask_chosen"])[0]
        rewards_rejected = model(input_ids=inputs["input_ids_rejected"],
                                 attention_mask=inputs["attention_mask_rejected"])[0]
        loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if return_outputs:
            return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
        return loss

    def evaluate(
            self,
            eval_dataset: Optional[Dataset] = None,
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Evaluate on `eval_dataset`, or on the trainer's own eval set when none is given."""
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Evaluation step.

        Returns (loss, rewards_chosen, rewards_rejected). compute_metrics
        receives rewards_chosen as predictions and rewards_rejected as labels.
        If `prediction_loss_only` is set, returns (loss, None, None).
        """
        device = model.device
        # The base Trainer.prediction_step runs under no_grad. This override must do
        # the same, otherwise every eval forward builds an autograd graph.
        with torch.no_grad():
            inputs_chosen = {
                "input_ids": inputs["input_ids_chosen"].to(device),
                "attention_mask": inputs["attention_mask_chosen"].to(device),
            }
            outputs_chosen = model(**inputs_chosen)
            rewards_chosen = outputs_chosen.logits.detach()
            inputs_rejected = {
                "input_ids": inputs["input_ids_rejected"].to(device),
                "attention_mask": inputs["attention_mask_rejected"].to(device),
            }
            outputs_rejected = model(**inputs_rejected)
            rewards_rejected = outputs_rejected.logits.detach()

        # Same loss as compute_loss.
        loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if prediction_loss_only:
            return (loss, None, None)

        return (loss, rewards_chosen, rewards_rejected)

    def save_model(self, output_dir=None, _internal_call=False):
        """
        Save the (LoRA) model and training args.

        Defaults to `self.args.output_dir` when `output_dir` is None.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.model.save_pretrained(output_dir)


def save_model(output_dir, model, tokenizer, args):
    """Save the model, the tokenizer and the training args to `output_dir`."""
    os.makedirs(output_dir, exist_ok=True)

    # Take care of distributed/parallel training: unwrap a `.module` wrapper if present.
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))


class CastOutputToFloat(torch.nn.Sequential):
    """Wrap a module and cast its output to float32."""

    def forward(self, x):
        return super().forward(x).to(torch.float32)


def print_trainable_parameters(model):
    """
    Print the trainable parameter count, the total parameter count and the trainable percentage.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def find_all_linear_names(peft_model, int4=False, int8=False):
    """
    Return the sorted, de-duplicated leaf names of all linear layers.

    Layers whose dotted name contains 'lm_head' or 'score' (the output heads)
    are skipped. With `int4`/`int8`, the bitsandbytes quantized linear classes
    are matched instead of `torch.nn.Linear`.
    """
    cls = torch.nn.Linear
    if int4 or int8:
        import bitsandbytes as bnb
        if int4:
            cls = bnb.nn.Linear4bit
        elif int8:
            cls = bnb.nn.Linear8bitLt
    lora_module_names = set()
    for name, module in peft_model.named_modules():
        if isinstance(module, cls):
            # Output heads are not LoRA targets.
            if 'lm_head' in name:
                continue
            if 'score' in name:
                continue
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return sorted(lora_module_names)


def main():
    """Parse the arguments, build the model and data, then train and/or evaluate the reward model."""
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logger.info(f"Model args: {model_args}")
    logger.info(f"Data args: {data_args}")
    logger.info(f"Training args: {training_args}")
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load model
    if not model_args.model_type:
        raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
    config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        if world_size > 1:
            # One full model copy per process, placed on this process's local GPU.
            model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        config = config_class.from_pretrained(
            model_args.model_name_or_path,
            num_labels=1,  # a reward model outputs one scalar score
            torch_dtype=torch_dtype,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir
        )
        if model_args.model_type in ['bloom', 'llama']:
            model = model_class.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                # torch_dtype must go to from_pretrained itself; setting it only on
                # the config does not change the dtype the weights are loaded in.
                torch_dtype=torch_dtype,
                load_in_8bit=model_args.load_in_8bit,
                device_map=model_args.device_map,
                trust_remote_code=model_args.trust_remote_code,
                cache_dir=model_args.cache_dir,
            )
            model.score = CastOutputToFloat(model.score)
        else:
            model = model_class.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                torch_dtype=torch_dtype,
                cache_dir=model_args.cache_dir,
                ignore_mismatched_sizes=True
            )
            model.to(training_args.device)
    else:
        raise ValueError(f"Error, model_name_or_path is None, RM must be loaded from a pre-trained model")

    # Load tokenizer
    if model_args.model_type == "bloom":
        model_args.use_fast_tokenizer = True
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "trust_remote_code": model_args.trust_remote_code,
    }
    tokenizer_name_or_path = model_args.tokenizer_name_or_path
    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0
    # NOTE: decoder sequence-classification heads (Bloom/Llama) find the token to
    # score from config.pad_token_id. Batches are padded to max_length, so the
    # model must use the same pad id as the collator, and padding must be on the
    # right so that the last non-pad token is the end of the text.
    tokenizer.padding_side = "right"
    model.config.pad_token_id = tokenizer.pad_token_id

    if training_args.use_peft:
        if model_args.load_in_8bit:
            # prepare_model_for_int8_training freezes the model's parameters, so it
            # must run on the base model BEFORE the LoRA adapters are attached.
            # Otherwise the adapters are frozen too and nothing is trained.
            model = prepare_model_for_int8_training(model)
        if training_args.peft_path is not None:
            logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
            model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
        else:
            logger.info("Init new peft model")
            target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
            if target_modules and 'all' in target_modules:
                target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
            modules_to_save = training_args.modules_to_save
            if modules_to_save is not None:
                modules_to_save = modules_to_save.split(',')
            logger.info(f"Peft target_modules: {target_modules}")
            logger.info(f"Peft lora_rank: {training_args.lora_rank}")
            peft_config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                target_modules=target_modules,
                inference_mode=False,
                r=training_args.lora_rank,
                lora_alpha=training_args.lora_alpha,
                lora_dropout=training_args.lora_dropout,
                modules_to_save=modules_to_save)
            model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        logger.info("Full parameters training")
        print_trainable_parameters(model)

    # Get reward dataset for tuning the reward model.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
            train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"train files: {', '.join(train_data_files)}")
            data_files["train"] = train_data_files
        if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
            eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"eval files: {', '.join(eval_data_files)}")
            data_files["validation"] = eval_data_files
        raw_datasets = load_dataset(
            'json',
            data_files=data_files,
            cache_dir=model_args.cache_dir,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    logger.info(f"Raw datasets: {raw_datasets}")

    # Preprocessing the datasets
    full_max_length = data_args.max_source_length + data_args.max_target_length

    def preprocess_reward_function(examples):
        """
        Tokenize each (question, response_chosen, response_rejected) row twice.

        Both sides use the template "Question: ...\\n\\nAnswer: ...". The result
        has input_ids / attention_mask for the chosen and for the rejected side.
        No truncation is applied here; over-long pairs are filtered out afterwards.
        """
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }
        for question, chosen, rejected in zip(examples["question"], examples["response_chosen"],
                                              examples["response_rejected"]):
            tokenized_chosen = tokenizer("Question: " + question + "\n\nAnswer: " + chosen)
            tokenized_rejected = tokenizer("Question: " + question + "\n\nAnswer: " + rejected)

            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
        return new_examples

    train_dataset = None
    max_train_samples = 0
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets['train']
        max_train_samples = len(train_dataset)
        if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
        with training_args.main_process_first(desc="Train dataset tokenization"):
            tokenized_dataset = train_dataset.shuffle().map(
                preprocess_reward_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=train_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            # Keep only pairs where both sides are non-empty and fit in full_max_length.
            train_dataset = tokenized_dataset.filter(
                lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
                    x['input_ids_chosen']) <= full_max_length
            )
            logger.debug(f"Num train_samples: {len(train_dataset)}")
            logger.debug("Tokenized training example:")
            logger.debug(tokenizer.decode(train_dataset[0]['input_ids_chosen']))

    eval_dataset = None
    max_eval_samples = 0
    if training_args.do_eval:
        with training_args.main_process_first(desc="Eval dataset tokenization"):
            if "validation" not in raw_datasets:
                raise ValueError("--do_eval requires a validation dataset")
            eval_dataset = raw_datasets["validation"]
            max_eval_samples = len(eval_dataset)
            if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
                eval_dataset = eval_dataset.select(range(max_eval_samples))
            logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
            tokenized_dataset = eval_dataset.map(
                preprocess_reward_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=eval_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            eval_dataset = tokenized_dataset.filter(
                lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
                    x['input_ids_chosen']) <= full_max_length
            )
            logger.debug(f"Num eval_samples: {len(eval_dataset)}")
            logger.debug("Tokenized eval example:")
            logger.debug(tokenizer.decode(eval_dataset[0]['input_ids_chosen']))

    # Initialize our Trainer
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    else:
        model.config.use_cache = True
    model.enable_input_require_grads()
    if torch.cuda.device_count() > 1:
        # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True
    trainer = RewardTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=RewardDataCollatorWithPadding(
            tokenizer=tokenizer, max_length=full_max_length, padding="max_length"
        ),
    )

    # Training
    if training_args.do_train:
        logger.info("*** Train ***")
        logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        metrics = train_result.metrics
        metrics["train_samples"] = max_train_samples
        logger.debug(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        # Only one process writes the final artifacts, so ranks do not race on output_dir.
        if trainer.is_world_process_zero():
            logger.info(f"Saving model checkpoint to {training_args.output_dir}")
            save_model(training_args.output_dir, model, tokenizer, training_args)

    # Evaluation
    # Every rank must take part in evaluate(): it gathers predictions across
    # processes, so running it on rank 0 only can hang under distributed
    # training. log_metrics/save_metrics already write only on the main process.
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()

        metrics["eval_samples"] = max_eval_samples
        # NOTE(review): exp(pairwise loss) is reported as "perplexity" for
        # compatibility; it is not a language-model perplexity.
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        logger.debug(f"Eval metrics: {metrics}")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()