MedicalGPT-main / supervised_finetuning.py
nengrenjie83's picture
Upload 28 files
b78b52f
raw
history blame contribute delete
No virus
36.2 kB
# -*- coding: utf-8 -*-
# Copyright 2023 XuMing([email protected]) and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Supervised fine-tuning (SFT) of causal language models (LLaMA, BLOOM, ChatGLM, Baichuan, ...) on multi-turn
conversation data, loaded from local JSON/JSONL folders or from a Hugging Face Hub dataset.
Part of this code is adapted from https://github.com/shibing624/textgen
"""
import math
import os
from dataclasses import dataclass, field
from glob import glob
from typing import List, Optional, Dict, Sequence
import torch
from datasets import load_dataset
from loguru import logger
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
from transformers import (
AutoConfig,
BloomForCausalLM,
AutoModel,
AutoModelForCausalLM,
LlamaTokenizer,
LlamaForCausalLM,
BloomTokenizerFast,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
BitsAndBytesConfig,
DataCollatorForSeq2Seq,
)
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_pt_utils import LabelSmoother
# Registry used by main(): maps the ``--model_type`` CLI value to the
# (config class, model class, tokenizer class) triple used to load the checkpoint.
# "auto" uses the generic Auto* classes.
MODEL_CLASSES = dict(
    bloom=(AutoConfig, BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoConfig, AutoModel, AutoTokenizer),
    llama=(AutoConfig, LlamaForCausalLM, LlamaTokenizer),
    baichuan=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoConfig, AutoModelForCausalLM, AutoTokenizer),
)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which pre-trained model/config/tokenizer we are going to fine-tune.

    ``__post_init__`` checks that ``model_type`` is one of the keys of ``MODEL_CLASSES`` and that
    ``model_name_or_path`` is set, so a misconfiguration fails at argument-parsing time with a
    readable message instead of a bare ``KeyError`` deep inside ``main()``.
    """
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. "
                "Required: supervised fine-tuning always starts from a pre-trained model."
            )
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The tokenizer to load. Defaults to model_name_or_path when not set."
            )
        },
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default="float16",
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )

    def __post_init__(self):
        """Validate required arguments.

        Raises:
            ValueError: if ``model_type`` is missing or not a key of ``MODEL_CLASSES``, or if
                ``model_name_or_path`` is missing.
        """
        available = ", ".join(MODEL_CLASSES.keys())
        if self.model_type is None:
            raise ValueError(
                "You must specify a valid model_type to run training. Available model types are " + available)
        if self.model_type not in MODEL_CLASSES:
            # main() indexes MODEL_CLASSES with this value; reject unknown types up front.
            raise ValueError(f"Unknown model_type '{self.model_type}'. Available model types are " + available)
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")
@dataclass
class DataTrainingArguments:
    """
    Data-side options: where the conversations come from, which prompt template renders them,
    how long prompts/answers may be, and how the datasets are preprocessed.
    """
    # Hub dataset (used instead of the local folders when set).
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    # Local folders searched recursively for *.json / *.jsonl files.
    train_file_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The train jsonl data file folder."},
    )
    validation_file_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The evaluation jsonl file folder."},
    )
    # Key into the conversation-template registry.
    template_name: Optional[str] = field(
        default="vicuna",
        metadata={"help": "The prompt template name."},
    )
    # Optional caps on the number of raw examples (values <= 0 or None mean "use everything").
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
        },
    )
    # Token budgets per turn; their sum is the maximum sequence length.
    max_source_length: Optional[int] = field(
        default=256,
        metadata={"help": "Max length of prompt input text"},
    )
    max_target_length: Optional[int] = field(
        default=256,
        metadata={"help": "Max length of output text"},
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    # Share of the train split carved out as validation when no validation split exists.
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={"help": "The percentage of the train set used as validation set in case there's no validation split"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    def __post_init__(self):
        """Warn about small sample caps and reject prompt budgets below 30 tokens."""
        cap = self.max_train_samples
        small_cap = cap is not None and 0 < cap <= 1000
        if small_cap:
            logger.warning("You may set max_train_samples = -1 to run all samples in production.")
        too_short = self.max_source_length < 30
        if too_short:
            raise ValueError("You must specify a valid max_source_length >= 30 to run training.")
@dataclass
class PeftArguments(TrainingArguments):
    """``TrainingArguments`` extended with the LoRA / QLoRA options consumed by ``main()``."""
    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    # Comma-separated module names; the special value "all" is expanded by main() to every
    # linear layer except the output head.
    target_modules: Optional[str] = field(
        default="all",
        metadata={
            "help": "Comma-separated module names to apply LoRA to; 'all' selects every linear layer "
                    "except lm_head / output_layer."
        },
    )
    lora_rank: Optional[int] = field(default=8, metadata={"help": "LoRA rank r."})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "Dropout probability of the LoRA layers."})
    lora_alpha: Optional[float] = field(default=32.0, metadata={"help": "LoRA scaling factor lora_alpha."})
    # Comma-separated; split by main() and passed to LoraConfig(modules_to_save=...).
    modules_to_save: Optional[str] = field(
        default=None,
        metadata={
            "help": "Comma-separated names of extra (non-LoRA) modules to train and save with the adapter."
        },
    )
    peft_path: Optional[str] = field(default=None, metadata={"help": "The path to the peft model"})
    qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
class CastOutputToFloat(torch.nn.Sequential):
    """Sequential wrapper whose forward output is always converted to ``torch.float32``.

    ``main()`` wraps ``model.lm_head`` with it so the logits are float32 regardless of the
    dtype the wrapped module computes in.
    """

    def forward(self, x):
        wrapped_out = super().forward(x)
        return wrapped_out.to(torch.float32)
@dataclass
class Conversation:
    """Prompt template for one chat format, optionally holding a conversation history.

    A dialog is rendered as alternating query/response strings. Each query is wrapped with
    ``prompt`` (which contains a ``{query}`` placeholder). The first query is prefixed by the
    system prompt followed by ``sep`` (only when the system prompt is non-empty). Every later
    query is prefixed by ``sep``. Responses are passed through unchanged.
    """

    # Registry key for this template.
    name: str
    # Default system prompt, used when the caller does not supply a non-empty one.
    system_prompt: str
    # Stored history: a list of [query, answer] pairs.
    messages: Optional[List[Sequence[str]]]
    # Speaker role names; informational only, not used when formatting.
    roles: Optional[Sequence[str]]
    # Query wrapper containing a "{query}" placeholder.
    prompt: str
    # Separator after a non-empty system prompt and before every query after the first.
    sep: str
    # End-of-response marker; main() installs it as the eos token when the tokenizer has none.
    stop_str: Optional[str] = "</s>"

    def get_prompt(
        self,
        messages: Optional[List[Sequence[str]]] = None,
        system_prompt: Optional[str] = ""
    ) -> str:
        """Render every query and response of the dialog into one concatenated string."""
        return "".join(self._format_example(messages, system_prompt))

    def get_dialog(
        self,
        messages: Optional[List[Sequence[str]]] = None,
        system_prompt: Optional[str] = ""
    ) -> List[str]:
        """Return 2 * n strings for n turns: even indices hold formatted queries, odd indices responses."""
        return self._format_example(messages, system_prompt)

    def _format_example(
        self,
        messages: Optional[List[Sequence[str]]] = None,
        system_prompt: Optional[str] = ""
    ) -> List[str]:
        """Format ``messages`` (or the stored history when falsy) into query/response pieces."""
        header = system_prompt or self.system_prompt
        if header:
            header = header + self.sep
        else:
            header = ""
        turns = messages or self.messages
        pieces = []
        for position, (query, answer) in enumerate(turns):
            lead = header if position == 0 else self.sep
            pieces.extend((lead + self.prompt.format(query=query), answer))
        return pieces

    def append_message(self, query: str, answer: str):
        """Append one [query, answer] turn to the stored history, in place."""
        self.messages.append([query, answer])
# Module-level registry of conversation templates, keyed by template name.
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation):
    """Store ``template`` in the registry under its ``name``, replacing any previous entry."""
    key = template.name
    conv_templates[key] = template
"""Vicuna v1.1 template
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_conv_template(
Conversation(
name="vicuna",
system_prompt="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
messages=[],
roles=("USER", "ASSISTANT"),
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
)
)
"""Alpaca template"""
register_conv_template(
Conversation(
name="alpaca",
system_prompt="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
messages=[],
roles=("### Instruction", "### Response"),
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
)
)
"""Baichuan-13B-Chat template
source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
Support: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_conv_template(
Conversation(
name="baichuan-chat",
system_prompt="",
messages=[],
roles=("<reserved_102>", "<reserved_103>"),
prompt=" <reserved_102> {query} <reserved_103> ",
sep="</s>",
)
)
"""ziya template"""
register_conv_template(
Conversation(
name="ziya",
system_prompt="",
messages=[],
roles=("<human>", "<bot>"),
prompt="<human>:{query}\n<bot>:",
sep="\n",
)
)
"""Linly template"""
register_conv_template(
Conversation(
name="linly",
system_prompt="",
messages=[],
roles=("User", "Bot"),
prompt="User: {query}\nBot: ",
sep="\n",
)
)
"""ChatGLM1 template
source: https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1307
"""
register_conv_template(
Conversation(
name="chatglm",
system_prompt="",
messages=[],
roles=("问", "答"),
prompt="问:{query}\n答:",
sep="\n",
)
)
"""ChatGLM2 template
source: https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1007
"""
register_conv_template(
# source:
Conversation(
name="chatglm2",
system_prompt="",
messages=[],
roles=("问", "答"),
prompt="问:{query}\n\n答:",
sep="\n\n",
)
)
"""Phoenix template"""
register_conv_template(
Conversation(
name="phoenix",
system_prompt="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
messages=[],
roles=("Human", "Assistant"),
prompt="Human: <s>{query}</s>Assistant: ",
sep="</s>",
)
)
"""belle template
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_conv_template(
Conversation(
name="belle",
system_prompt="",
messages=[],
roles=("Human", "Belle"),
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
)
)
"""aquila template
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_conv_template(
Conversation(
name="aquila",
system_prompt="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
messages=[],
roles=("Human", "Assistant"),
prompt="Human: {query}###Assistant: ",
sep="###",
)
)
"""intern template
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_conv_template(
Conversation(
name="intern",
system_prompt="",
messages=[],
roles=("<|User|>", "<|Bot|>"),
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
stop_str="<eoa>",
)
)
"""StarChat template"""
register_conv_template(
Conversation(
name="starchat",
system_prompt="<system>\n",
messages=[],
roles=("<|user|>", "<|assistant|>"),
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
stop_str="<|end|>",
)
)
"""llama2 template
reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
"""
register_conv_template(
Conversation(
name="llama2",
system_prompt="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, racist, sexist, "
"toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
messages=[],
roles=("[INST]", "[/INST]"),
prompt=" [INST] {query} [/INST] ",
sep="</s>",
)
)
"""llama2-zh template
Sources: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
"""
register_conv_template(
Conversation(
name="llama2-zh",
system_prompt="<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n",
messages=[],
roles=("[INST]", "[/INST]"),
prompt=" [INST] {query} [/INST] ",
sep="</s>",
)
)
"""XVERSE template
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_conv_template(
Conversation(
name="xverse",
system_prompt="",
messages=[],
roles=("Human", "Assistant"),
prompt="Human: {query}\n\nAssistant: ",
sep="</s>",
)
)
"""Qwen template
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
chatml: https://xbot123.com/645a461b922f176d7cfdbc2d/
"""
register_conv_template(
Conversation(
name="chatml",
system_prompt="You are a helpful assistant.",
messages=[],
roles=("user", "assistant"),
prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
sep="<|im_end|>\n",
stop_str="<|im_end|>",
)
)
def get_conv_template(name: str) -> Conversation:
    """Return an independent copy of the registered conversation template ``name``.

    Registered templates are module-level singletons, and ``Conversation.append_message``
    mutates ``messages`` in place. Returning the shared instance would let one caller's
    history leak into every later user of the same template, so a deep copy is returned.

    Raises:
        KeyError: if no template with that name has been registered.
    """
    import copy

    try:
        template = conv_templates[name]
    except KeyError:
        available = ", ".join(sorted(conv_templates))
        raise KeyError(f"Unknown conversation template {name!r}; available: {available}") from None
    return copy.deepcopy(template)
class SavePeftModelTrainer(Trainer):
    """
    Trainer for lora models: ``save_model`` writes only ``model.save_pretrained`` output
    (the adapter weights for a PEFT model) plus the training arguments.
    """

    def save_model(self, output_dir=None, _internal_call=False):
        """Save the LoRA model and the training arguments.

        Args:
            output_dir: Target directory. ``None`` falls back to ``self.args.output_dir``,
                matching the default of ``Trainer.save_model`` (previously ``None`` crashed
                inside ``os.makedirs``).
            _internal_call: Accepted for signature compatibility with ``Trainer.save_model``; unused.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Only the non-distributed process or local rank 0 writes files.
        if self.args.local_rank in [-1, 0]:
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
            self.model.save_pretrained(output_dir)
def save_model(output_dir, model, tokenizer, args):
    """Write model weights, tokenizer files and training arguments into ``output_dir``.

    The directory is created on every process. Only the process whose ``args.local_rank`` is
    -1 or 0 writes files. If ``model`` exposes a ``.module`` attribute (DataParallel/DDP-style
    wrapper), the inner module is saved.
    """
    os.makedirs(output_dir, exist_ok=True)
    unwrapped = getattr(model, "module", model)
    if args.local_rank not in (-1, 0):
        return
    unwrapped.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Under DeepSpeed ZeRO-3 a partitioned parameter can report ``numel() == 0`` on the local
    rank while exposing its full size as ``ds_numel``. That value is used when present, so the
    counts are not silently zero. A model without parameters reports a trainable percentage
    of 0.0 instead of raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if num_params == 0 and hasattr(param, "ds_numel"):
            # ZeRO-3 shards weights across ranks; the local shard may be empty.
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
def find_all_linear_names(peft_model, int4=False, int8=False):
    """Return the sorted, de-duplicated short names of all linear layers (qlora-paper recipe).

    The class matched is ``torch.nn.Linear`` by default, ``bnb.nn.Linear4bit`` when ``int4``
    is set, or ``bnb.nn.Linear8bitLt`` when ``int8`` is set (bitsandbytes is imported lazily).
    Modules whose qualified name contains ``lm_head`` or ``output_layer`` are skipped. Each
    match contributes the last component of its dotted name (e.g. ``layers.0.q_proj`` -> ``q_proj``).
    """
    linear_cls = torch.nn.Linear
    if int4:
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear4bit
    elif int8:
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear8bitLt

    # The output head is excluded from LoRA targets.
    excluded_tags = ("lm_head", "output_layer")
    short_names = {
        qualified.split(".")[-1]
        for qualified, module in peft_model.named_modules()
        if isinstance(module, linear_cls) and not any(tag in qualified for tag in excluded_tags)
    }
    return sorted(short_names)
def main():
    """Run supervised fine-tuning (SFT) from the command line.

    Pipeline:
      1. Parse ``ModelArguments`` / ``DataTrainingArguments`` / ``PeftArguments``.
      2. Load the tokenizer and the prompt template selected by ``--template_name``.
      3. Load the raw datasets (Hub dataset or local JSON/JSONL folders) and tokenize the
         conversations into ``input_ids`` / ``labels``, with prompt tokens masked by ``IGNORE_INDEX``.
      4. Load the base model (optionally 8-bit or 4-bit QLoRA) and attach LoRA adapters when
         ``--use_peft`` is set.
      5. Train, save and evaluate with ``SavePeftModelTrainer``.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logger.info(f"Model args: {model_args}")
    logger.info(f"Data args: {data_args}")
    logger.info(f"Training args: {training_args}")
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    if not model_args.model_type:
        raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
    config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]

    # Load tokenizer
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "trust_remote_code": model_args.trust_remote_code,
    }
    tokenizer_name_or_path = model_args.tokenizer_name_or_path
    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    prompt_template = get_conv_template(data_args.template_name)
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token = prompt_template.stop_str  # eos token is required for SFT
        logger.info("Add eos token: {}".format(tokenizer.eos_token))
    if tokenizer.pad_token_id is None:
        if tokenizer.unk_token_id is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))
    logger.debug(f"Tokenizer: {tokenizer}")
    # Label value used for tokens that must not contribute to the loss (prompt tokens, padding).
    IGNORE_INDEX = LabelSmoother.ignore_index if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    # Get datasets
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        # Loading a dataset from local files.
        data_files = {}
        if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
            train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"train files: {train_data_files}")
            data_files["train"] = train_data_files
        if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
            eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"eval files: {eval_data_files}")
            data_files["validation"] = eval_data_files
        raw_datasets = load_dataset(
            'json',
            data_files=data_files,
            cache_dir=model_args.cache_dir,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    logger.info(f"Raw datasets: {raw_datasets}")

    # Preprocessing the datasets
    max_source_length = data_args.max_source_length
    max_target_length = data_args.max_target_length
    max_length = max_source_length + max_target_length

    def preprocess_function(examples):
        """
        Preprocessing the datasets.
        part of code modified from https://github.com/lm-sys/FastChat

        Each ``examples['conversations']`` entry is a list of ``{"from": ..., "value": ...}``
        dicts with roles "human"/"gpt". Conversations are rendered with ``prompt_template``,
        tokenized turn by turn, truncated to the source/target budgets, and stopped once
        ``max_length`` would be exceeded. Prompt tokens are labelled ``IGNORE_INDEX`` so only
        responses (plus one eos per turn) are trained on.
        """
        input_ids_list = []
        targets_list = []
        roles = ["human", "gpt"]

        def get_dialog(examples):
            """Yield the formatted [query, response, ...] list of every usable conversation."""
            for i, source in enumerate(examples['conversations']):
                if len(source) < 2:
                    continue
                data_role = source[0].get("from", "")
                if data_role != roles[0]:
                    # Skip the first one if it is not from human
                    source = source[1:]
                if len(source) < 2:
                    continue
                messages = []
                for j, sentence in enumerate(source):
                    data_role = sentence.get("from", "")
                    if data_role not in roles:
                        logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
                        break
                    # Keep only messages whose role matches the expected human/gpt alternation.
                    if data_role == roles[j % 2]:
                        messages.append(sentence["value"])
                if len(messages) < 2 or len(messages) % 2 != 0:
                    continue
                # Convert the list to pairs of elements
                history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
                yield prompt_template.get_dialog(history_messages)

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                # Special tokens (e.g. bos) are only added for the very first prompt.
                source_ids = tokenizer.encode(text=dialog[2 * i], add_special_tokens=(i == 0))
                target_ids = tokenizer.encode(text=dialog[2 * i + 1], add_special_tokens=False)

                if len(source_ids) > max_source_length:
                    source_ids = source_ids[:max_source_length]
                if len(target_ids) > max_target_length - 1:  # eos token
                    target_ids = target_ids[:max_target_length - 1]
                # Drop eos tokens produced by the template text; one eos is appended per turn below.
                if len(source_ids) > 0 and source_ids[0] == tokenizer.eos_token_id:
                    source_ids = source_ids[1:]
                if len(target_ids) > 0 and target_ids[-1] == tokenizer.eos_token_id:
                    target_ids = target_ids[:-1]
                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]  # add eos token for each turn
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            input_ids_list.append(input_ids)
            targets_list.append(labels)

        return dict(
            input_ids=input_ids_list,
            labels=targets_list,
        )

    def filter_empty_labels(example):
        """Remove empty labels dataset."""
        return not all(label == IGNORE_INDEX for label in example["labels"])

    train_dataset = None
    max_train_samples = 0
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets['train']
        max_train_samples = len(train_dataset)
        if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
        with training_args.main_process_first(desc="Train dataset tokenization"):
            train_dataset = train_dataset.shuffle().map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=train_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            train_dataset = train_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
            logger.debug(f"Num train_samples: {len(train_dataset)}")
            logger.debug("Tokenized training example:")
            logger.debug(f"Decode input_ids[0]: {tokenizer.decode(train_dataset[0]['input_ids'])}")
            replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id
                               for label in list(train_dataset[0]['labels'])]
            logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")

    eval_dataset = None
    max_eval_samples = 0
    if training_args.do_eval:
        with training_args.main_process_first(desc="Eval dataset tokenization"):
            if "validation" not in raw_datasets:
                raise ValueError("--do_eval requires a validation dataset")
            eval_dataset = raw_datasets["validation"]
            max_eval_samples = len(eval_dataset)
            if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
                eval_dataset = eval_dataset.select(range(max_eval_samples))
            logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=eval_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            eval_dataset = eval_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
            logger.debug(f"Num eval_samples: {len(eval_dataset)}")
            logger.debug("Tokenized eval example:")
            logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))

    # Load model
    if not model_args.model_name_or_path:
        raise ValueError("Error, model_name_or_path is None, SFT must be loaded from a pre-trained model")
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Put the whole model on this process's GPU. LOCAL_RANK is normally exported by the
        # distributed launcher; default to 0 instead of raising KeyError when it is absent.
        model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
    if training_args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
        logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
    # NOTE(review): torch_dtype is passed to the config (and as the QLoRA compute dtype) but not
    # to model_class.from_pretrained, so the weights are presumably loaded in the library's default
    # dtype rather than --torch_dtype. Passing it through would change the dtype of the frozen base
    # and may interact with fp16 mixed precision; confirm before changing.
    config = config_class.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch_dtype,
        cache_dir=model_args.cache_dir
    )
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        load_in_8bit=model_args.load_in_8bit,
        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
        device_map=model_args.device_map,
        trust_remote_code=model_args.trust_remote_code,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        ) if training_args.qlora else None,
    )
    if hasattr(model, 'lm_head'):
        model.lm_head = CastOutputToFloat(model.lm_head)

    if training_args.use_peft:
        logger.info("Fine-tuning method: LoRA(PEFT)")
        if model_args.load_in_8bit or training_args.qlora:
            # PEFT's k-bit preparation helper freezes every parameter that exists when it runs.
            # It must therefore be applied to the base model BEFORE the LoRA adapters are
            # created/loaded; the previous order (after get_peft_model) also froze the adapters.
            model = prepare_model_for_int8_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )
        if training_args.peft_path is not None:
            logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
            model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
        else:
            target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
            if target_modules and 'all' in target_modules:
                target_modules = find_all_linear_names(
                    model, int4=training_args.qlora, int8=model_args.load_in_8bit
                )
            modules_to_save = training_args.modules_to_save
            if modules_to_save is not None:
                modules_to_save = modules_to_save.split(',')
            logger.info(f"Peft target_modules: {target_modules}")
            logger.info(f"Peft lora_rank: {training_args.lora_rank}")
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                target_modules=target_modules,
                inference_mode=False,
                r=training_args.lora_rank,
                lora_alpha=training_args.lora_alpha,
                lora_dropout=training_args.lora_dropout,
                modules_to_save=modules_to_save)
            model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        logger.info("Fine-tuning method: Full parameters training")
        model = model.float()
        print_trainable_parameters(model)
    logger.debug(f"Model: {model}")

    # Initialize our Trainer
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    else:
        model.config.use_cache = True
    model.enable_input_require_grads()
    if not ddp and torch.cuda.device_count() > 1:
        # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
    # Initialize our Trainer
    trainer = SavePeftModelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        logger.info("*** Train ***")
        sample = next(iter(trainer.get_train_dataloader()))
        logger.debug(f"Train dataloader example: {sample}")
        logger.debug(f"Detail input_ids: {list(sample['input_ids'])[:3]}, \nlabels: {list(sample['labels'])[:3]}")
        logger.debug(f"Decode input_ids[0]: {tokenizer.decode(sample['input_ids'][0])}")
        replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in sample['labels'][0]]
        logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        metrics = train_result.metrics
        metrics["train_samples"] = max_train_samples
        logger.debug(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        model.config.use_cache = True  # enable cache after training
        trainer.save_state()
        logger.info(f"Saving model checkpoint to {training_args.output_dir}")
        save_model(training_args.output_dir, model, tokenizer, training_args)

    # Evaluation
    if training_args.do_eval:
        # evaluate() runs on EVERY rank: evaluation gathers results across processes, so calling
        # it only on world-process-zero (as before) blocks in distributed runs.
        # log_metrics/save_metrics already act on the main process only.
        is_main = trainer.is_world_process_zero()
        if is_main:
            logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()

        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        if is_main:
            logger.debug(f"Eval metrics: {metrics}")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
if __name__ == "__main__":
main()