# -*- coding: utf-8 -*-
# Copyright 2023 XuMing(xuming624@qq.com) and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Part of this code is adapted from https://github.com/shibing624/textgen
"""

import math
import os
from dataclasses import dataclass, field
from glob import glob
from typing import List, Optional, Dict, Sequence

import torch
from datasets import load_dataset
from loguru import logger
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
from transformers import (
    AutoConfig,
    BloomForCausalLM,
    AutoModel,
    AutoModelForCausalLM,
    LlamaTokenizer,
    LlamaForCausalLM,
    BloomTokenizerFast,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
)
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_pt_utils import LabelSmoother

# Maps the --model_type CLI value to (config class, model class, tokenizer class).
MODEL_CLASSES = {
    "bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoConfig, AutoModel, AutoTokenizer),
    "llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
    "baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
}


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_type: str = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default="float16",
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )

    def __post_init__(self):
        """Validate that both a model type and a model path were supplied."""
        if self.model_type is None:
            raise ValueError(
                "You must specify a valid model_type to run training. Available model types are " + ", ".join(
                    MODEL_CLASSES.keys()))
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
    template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    def __post_init__(self):
        """Warn about debug-sized sample limits and reject too-short prompt lengths."""
        if self.max_train_samples is not None and 0 < self.max_train_samples <= 1000:
            logger.warning("You may set max_train_samples = -1 to run all samples in production.")
        if self.max_source_length < 30:
            raise ValueError("You must specify a valid max_source_length >= 30 to run training.")


@dataclass
class PeftArguments(TrainingArguments):
    """HF `TrainingArguments` extended with LoRA / QLoRA options."""

    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    # Comma-separated module names; "all" means every linear layer except the output head.
    target_modules: Optional[str] = field(default="all")
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.05)
    lora_alpha: Optional[float] = field(default=32.0)
    # Comma-separated names of extra modules to train fully and save alongside the adapter.
    modules_to_save: Optional[str] = field(default=None)
    peft_path: Optional[str] = field(default=None, metadata={"help": "The path to the peft model"})
    qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})


class CastOutputToFloat(torch.nn.Sequential):
    """Cast the output of the model to float"""

    def forward(self, x):
        return super().forward(x).to(torch.float32)


@dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The system prompt
    system_prompt: str
    # All messages. format: list of [question, answer]
    messages: Optional[List[Sequence[str]]]
    # The roles of the speakers
    roles: Optional[Sequence[str]]
    # Conversation prompt; must contain a `{query}` placeholder
    prompt: str
    # Separator inserted before every turn after the first (and after a non-empty system prompt)
    sep: str
    # Stop token; main() also assigns it as the tokenizer's eos token when the tokenizer has none,
    # so it must not be empty.
    stop_str: Optional[str] = "</s>"

    def get_prompt(
            self,
            messages: Optional[List[Sequence[str]]] = None,
            system_prompt: Optional[str] = ""
    ) -> str:
        """
        Returns a string containing prompt without response.
        """
        return "".join(self._format_example(messages, system_prompt))

    def get_dialog(
            self,
            messages: Optional[List[Sequence[str]]] = None,
            system_prompt: Optional[str] = ""
    ) -> List[str]:
        """
        Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
        """
        return self._format_example(messages, system_prompt)

    def _format_example(
            self,
            messages: Optional[List[Sequence[str]]] = None,
            system_prompt: Optional[str] = ""
    ) -> List[str]:
        """Render each [query, answer] pair into alternating formatted-query / raw-answer strings."""
        system_prompt = system_prompt or self.system_prompt
        system_prompt = system_prompt + self.sep if system_prompt else ""  # add separator for non-empty system prompt
        messages = messages or self.messages
        convs = []
        for turn_idx, [user_query, bot_resp] in enumerate(messages):
            if turn_idx == 0:
                convs.append(system_prompt + self.prompt.format(query=user_query))
                convs.append(bot_resp)
            else:
                convs.append(self.sep + self.prompt.format(query=user_query))
                convs.append(bot_resp)
        return convs

    def append_message(self, query: str, answer: str):
        """Append a new message."""
        self.messages.append([query, answer])


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation):
    """Register a new conversation template."""
    conv_templates[template.name] = template


"""Vicuna v1.1 template
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
          https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_conv_template(
    Conversation(
        name="vicuna",
        system_prompt="A chat between a curious user and an artificial intelligence assistant. "
                      "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        messages=[],
        roles=("USER", "ASSISTANT"),
        prompt="USER: {query} ASSISTANT: ",
        sep="",
    )
)

"""Alpaca template"""
register_conv_template(
    Conversation(
        name="alpaca",
        system_prompt="Below is an instruction that describes a task. "
                      "Write a response that appropriately completes the request.",
        messages=[],
        roles=("### Instruction", "### Response"),
        prompt="### Instruction:\n{query}\n\n### Response:\n",
        sep="\n\n",
    )
)

"""Baichuan-13B-Chat template
source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
Support: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_conv_template(
    Conversation(
        name="baichuan-chat",
        system_prompt="",
        messages=[],
        roles=("<reserved_102>", "<reserved_103>"),
        prompt=" <reserved_102> {query} <reserved_103> ",
        sep="",
    )
)

"""ziya template"""
register_conv_template(
    Conversation(
        name="ziya",
        system_prompt="",
        messages=[],
        roles=("<human>", "<bot>"),
        prompt="<human>:{query}\n<bot>:",
        sep="\n",
    )
)

"""Linly template"""
register_conv_template(
    Conversation(
        name="linly",
        system_prompt="",
        messages=[],
        roles=("User", "Bot"),
        prompt="User: {query}\nBot: ",
        sep="\n",
    )
)

"""ChatGLM1 template
source: https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1307
"""
register_conv_template(
    Conversation(
        name="chatglm",
        system_prompt="",
        messages=[],
        roles=("问", "答"),
        prompt="问:{query}\n答:",
        sep="\n",
    )
)

"""ChatGLM2 template
source: https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1007
"""
register_conv_template(
    Conversation(
        name="chatglm2",
        system_prompt="",
        messages=[],
        roles=("问", "答"),
        prompt="问:{query}\n\n答:",
        sep="\n\n",
    )
)

"""Phoenix template"""
register_conv_template(
    Conversation(
        name="phoenix",
        system_prompt="A chat between a curious human and an artificial intelligence assistant. "
                      "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        messages=[],
        roles=("Human", "Assistant"),
        prompt="Human: <s>{query}</s>Assistant: <s>",
        sep="",
    )
)

"""belle template
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_conv_template(
    Conversation(
        name="belle",
        system_prompt="",
        messages=[],
        roles=("Human", "Belle"),
        prompt="Human: {query}\n\nBelle: ",
        sep="\n\n",
    )
)

"""aquila template
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_conv_template(
    Conversation(
        name="aquila",
        system_prompt="A chat between a curious human and an artificial intelligence assistant. "
                      "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        messages=[],
        roles=("Human", "Assistant"),
        prompt="Human: {query}###Assistant: ",
        sep="###",
    )
)

"""intern template
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_conv_template(
    Conversation(
        name="intern",
        system_prompt="",
        messages=[],
        roles=("<|User|>", "<|Bot|>"),
        prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
        sep="<eoa>\n",
        stop_str="<eoa>",
    )
)

"""StarChat template"""
register_conv_template(
    Conversation(
        name="starchat",
        system_prompt="<system>\n",
        messages=[],
        roles=("<|user|>", "<|assistant|>"),
        prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
        sep="<|end|>\n",
        stop_str="<|end|>",
    )
)

"""llama2 template
reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
"""
register_conv_template(
    Conversation(
        name="llama2",
        system_prompt="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
                      "Always answer as helpfully as possible, while being safe. "
                      "Your answers should not include any harmful, unethical, racist, sexist, "
                      "toxic, dangerous, or illegal content. "
                      "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
                      "If a question does not make any sense, or is not factually coherent, "
                      "explain why instead of answering something not correct. "
                      "If you don't know the answer to a question, please don't share false information."
                      "\n<</SYS>>\n\n",
        messages=[],
        roles=("[INST]", "[/INST]"),
        prompt=" [INST] {query} [/INST] ",
        sep="",
    )
)

"""llama2-zh template
Sources: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
"""
register_conv_template(
    Conversation(
        name="llama2-zh",
        system_prompt="<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n",
        messages=[],
        roles=("[INST]", "[/INST]"),
        prompt=" [INST] {query} [/INST] ",
        sep="",
    )
)

"""XVERSE template
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_conv_template(
    Conversation(
        name="xverse",
        system_prompt="",
        messages=[],
        roles=("Human", "Assistant"),
        prompt="Human: {query}\n\nAssistant: ",
        sep="",
    )
)

"""Qwen template
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
chatml: https://xbot123.com/645a461b922f176d7cfdbc2d/
"""
register_conv_template(
    Conversation(
        name="chatml",
        system_prompt="You are a helpful assistant.",
        messages=[],
        roles=("user", "assistant"),
        prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
        sep="<|im_end|>\n",
        stop_str="<|im_end|>",
    )
)


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name]


class SavePeftModelTrainer(Trainer):
    """
    Trainer for lora models
    """

    def save_model(self, output_dir=None, _internal_call=False):
        """Save the LoRA model (adapter weights via `save_pretrained`) plus the training args.

        Args:
            output_dir: target directory; falls back to `self.args.output_dir` when None,
                mirroring `Trainer.save_model`.
            _internal_call: accepted for signature compatibility with `Trainer.save_model`; unused.
        """
        # Trainer.save_model allows output_dir=None; os.makedirs(None) would raise TypeError.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Only the (local) main process writes files, to avoid concurrent writes.
        if self.args.local_rank in [-1, 0]:
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
            self.model.save_pretrained(output_dir)


def save_model(output_dir, model, tokenizer, args):
    """Save the model and the tokenizer."""
    os.makedirs(output_dir, exist_ok=True)

    # Take care of distributed/parallel training
    model_to_save = model.module if hasattr(model, "module") else model
    if args.local_rank in [-1, 0]:
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def find_all_linear_names(peft_model, int4=False, int8=False):
    """Find all linear layer names in the model. reference from qlora paper.

    Args:
        peft_model: model whose submodules are scanned.
        int4: match bitsandbytes 4-bit linear layers instead of `torch.nn.Linear`.
        int8: match bitsandbytes 8-bit linear layers (ignored when `int4` is set).

    Returns:
        Sorted list of distinct leaf module names (e.g. "q_proj"), excluding output heads.
    """
    cls = torch.nn.Linear
    if int4 or int8:
        import bitsandbytes as bnb
        if int4:
            cls = bnb.nn.Linear4bit
        elif int8:
            cls = bnb.nn.Linear8bitLt
    lora_module_names = set()
    for name, module in peft_model.named_modules():
        if isinstance(module, cls):
            # last layer is not add to lora_module_names
            if 'lm_head' in name:
                continue
            if 'output_layer' in name:
                continue
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return sorted(lora_module_names)


def main():
    """Parse CLI arguments, build datasets and the (optionally LoRA-wrapped) model, then train/evaluate."""
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logger.info(f"Model args: {model_args}")
    logger.info(f"Data args: {data_args}")
    logger.info(f"Training args: {training_args}")
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    if not model_args.model_type:
        raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
    config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]

    # Load tokenizer
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "trust_remote_code": model_args.trust_remote_code,
    }
    tokenizer_name_or_path = model_args.tokenizer_name_or_path
    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    prompt_template = get_conv_template(data_args.template_name)
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token = prompt_template.stop_str  # eos token is required for SFT
        logger.info("Add eos token: {}".format(tokenizer.eos_token))
    if tokenizer.pad_token_id is None:
        if tokenizer.unk_token_id is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    logger.debug(f"Tokenizer: {tokenizer}")
    # Label value excluded from the loss (prompt tokens and padding).
    IGNORE_INDEX = LabelSmoother.ignore_index if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    # Get datasets
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        # Loading a dataset from local files.
        data_files = {}
        if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
            train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"train files: {train_data_files}")
            data_files["train"] = train_data_files
        if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
            eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"eval files: {eval_data_files}")
            data_files["validation"] = eval_data_files
        raw_datasets = load_dataset(
            'json',
            data_files=data_files,
            cache_dir=model_args.cache_dir,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    logger.info(f"Raw datasets: {raw_datasets}")

    # Preprocessing the datasets
    max_source_length = data_args.max_source_length
    max_target_length = data_args.max_target_length
    max_length = max_source_length + max_target_length

    def preprocess_function(examples):
        """
        Preprocessing the datasets.
            part of code modified from https://github.com/lm-sys/FastChat

        Tokenizes each multi-turn conversation into one sequence. Prompt tokens are masked with
        IGNORE_INDEX in `labels`, and each turn ends with an eos token.
        """
        input_ids_list = []
        targets_list = []
        roles = ["human", "gpt"]

        def get_dialog(examples):
            """Yield formatted [query, answer, query, answer, ...] lists for well-formed conversations."""
            for i, source in enumerate(examples['conversations']):
                if len(source) < 2:
                    continue
                data_role = source[0].get("from", "")
                if data_role != roles[0]:
                    # Skip the first one if it is not from human
                    source = source[1:]
                if len(source) < 2:
                    continue
                messages = []
                for j, sentence in enumerate(source):
                    data_role = sentence.get("from", "")
                    if data_role not in roles:
                        logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
                        break
                    if data_role != roles[j % 2]:
                        # Turns must alternate human/gpt. Stop at the first out-of-order turn so a query
                        # is never paired with an answer that belongs to a different question.
                        break
                    messages.append(sentence["value"])
                # Drop a trailing query that has no answer so the messages form complete pairs.
                if len(messages) % 2 != 0:
                    messages = messages[:-1]
                if len(messages) < 2:
                    continue
                # Convert the list to pairs of elements
                history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
                yield prompt_template.get_dialog(history_messages)

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                # Special tokens (e.g. bos) are only added to the very first query.
                source_ids = tokenizer.encode(text=dialog[2 * i], add_special_tokens=(i == 0))
                target_ids = tokenizer.encode(text=dialog[2 * i + 1], add_special_tokens=False)

                if len(source_ids) > max_source_length:
                    source_ids = source_ids[:max_source_length]
                if len(target_ids) > max_target_length - 1:  # eos token
                    target_ids = target_ids[:max_target_length - 1]
                # Strip eos tokens produced by the template text; eos is appended explicitly below.
                if len(source_ids) > 0 and source_ids[0] == tokenizer.eos_token_id:
                    source_ids = source_ids[1:]
                if len(target_ids) > 0 and target_ids[-1] == tokenizer.eos_token_id:
                    target_ids = target_ids[:-1]
                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]  # add eos token for each turn
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            input_ids_list.append(input_ids)
            targets_list.append(labels)

        return dict(
            input_ids=input_ids_list,
            labels=targets_list,
        )

    def filter_empty_labels(example):
        """Remove empty labels dataset."""
        return not all(label == IGNORE_INDEX for label in example["labels"])

    train_dataset = None
    max_train_samples = 0
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets['train']
        max_train_samples = len(train_dataset)
        if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
        with training_args.main_process_first(desc="Train dataset tokenization"):
            train_dataset = train_dataset.shuffle().map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=train_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            train_dataset = train_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
            logger.debug(f"Num train_samples: {len(train_dataset)}")
            logger.debug("Tokenized training example:")
            logger.debug(f"Decode input_ids[0]: {tokenizer.decode(train_dataset[0]['input_ids'])}")
            replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id
                               for label in list(train_dataset[0]['labels'])]
            logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")

    eval_dataset = None
    max_eval_samples = 0
    if training_args.do_eval:
        with training_args.main_process_first(desc="Eval dataset tokenization"):
            if "validation" not in raw_datasets:
                raise ValueError("--do_eval requires a validation dataset")
            eval_dataset = raw_datasets["validation"]
            max_eval_samples = len(eval_dataset)
            if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
                eval_dataset = eval_dataset.select(range(max_eval_samples))
            logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=eval_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            eval_dataset = eval_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
            logger.debug(f"Num eval_samples: {len(eval_dataset)}")
            logger.debug("Tokenized eval example:")
            logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))

    # Load model
    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        ddp = world_size != 1
        if ddp:
            # Pin the whole model to this process's GPU; tolerate a missing LOCAL_RANK.
            model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        if training_args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
            logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
        config = config_class.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            cache_dir=model_args.cache_dir
        )
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            load_in_8bit=model_args.load_in_8bit,
            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
            device_map=model_args.device_map,
            trust_remote_code=model_args.trust_remote_code,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch_dtype,
            ) if training_args.qlora else None,
        )
        if training_args.use_peft and hasattr(model, 'lm_head'):
            # Upcast logits to fp32 for a stable loss. Only for PEFT: full-parameter training casts the
            # model to fp32 anyway, and the wrapper renames the saved weight to `lm_head.0.weight`,
            # which a plain `from_pretrained` reload would not recognise.
            model.lm_head = CastOutputToFloat(model.lm_head)
    else:
        raise ValueError(f"Error, model_name_or_path is None, SFT must be loaded from a pre-trained model")

    if training_args.use_peft:
        logger.info("Fine-tuning method: LoRA(PEFT)")
        if model_args.load_in_8bit:
            # Must run BEFORE adapters are attached/loaded: it sets requires_grad=False on every
            # existing parameter, so calling it afterwards would freeze the LoRA weights as well.
            model = prepare_model_for_int8_training(model)
        if training_args.peft_path is not None:
            logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
            model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
        else:
            target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
            if target_modules and 'all' in target_modules:
                target_modules = find_all_linear_names(
                    model, int4=training_args.qlora, int8=model_args.load_in_8bit
                )
            modules_to_save = training_args.modules_to_save
            if modules_to_save is not None:
                modules_to_save = modules_to_save.split(',')
            logger.info(f"Peft target_modules: {target_modules}")
            logger.info(f"Peft lora_rank: {training_args.lora_rank}")
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                target_modules=target_modules,
                inference_mode=False,
                r=training_args.lora_rank,
                lora_alpha=training_args.lora_alpha,
                lora_dropout=training_args.lora_dropout,
                modules_to_save=modules_to_save)
            model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        logger.info("Fine-tuning method: Full parameters training")
        model = model.float()
        print_trainable_parameters(model)
    logger.debug(f"Model: {model}")

    # Initialize our Trainer
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    else:
        model.config.use_cache = True
    model.enable_input_require_grads()
    if not ddp and torch.cuda.device_count() > 1:
        # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
    # Initialize our Trainer
    trainer = SavePeftModelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        logger.info("*** Train ***")
        sample = next(iter(trainer.get_train_dataloader()))
        logger.debug(f"Train dataloader example: {sample}")
        logger.debug(f"Detail input_ids: {list(sample['input_ids'])[:3]}, \nlabels: {list(sample['labels'])[:3]}")
        logger.debug(f"Decode input_ids[0]: {tokenizer.decode(sample['input_ids'][0])}")
        replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in sample['labels'][0]]
        logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        metrics = train_result.metrics
        metrics["train_samples"] = max_train_samples
        logger.debug(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        model.config.use_cache = True  # enable cache after training
        trainer.save_state()
        logger.info(f"Saving model checkpoint to {training_args.output_dir}")
        save_model(training_args.output_dir, model, tokenizer, training_args)

    # Evaluation
    # Every rank must call evaluate(): it gathers results across processes, so running it on
    # rank 0 only would deadlock under DDP. log_metrics/save_metrics already write only on the
    # main process.
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()

        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        logger.debug(f"Eval metrics: {metrics}")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()