MedicalGPT-main / merge_peft_adapter.py
nengrenjie83's picture
Upload 28 files
b78b52f
raw
history blame contribute delete
No virus
3.88 kB
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Merge a LoRA (PEFT) adapter into its base model and save the merged model and tokenizer.
Usage:
python merge_peft_adapter.py \
--model_type llama \
--base_model_name_or_path path/to/llama/model \
--tokenizer_path path/to/llama/tokenizer \
--peft_model_path path/to/lora/model \
--output_dir path/to/output/dir
After merging, chatglm and baichuan models also need their Python scripts copied to the output dir.
"""
import argparse
import torch
from peft import PeftModel, PeftConfig
from transformers import (
AutoModel,
AutoTokenizer,
BloomForCausalLM,
BloomTokenizerFast,
AutoModelForCausalLM,
LlamaTokenizer,
LlamaForCausalLM,
AutoModelForSequenceClassification,
)
# Registry keyed by the --model_type command-line value. Each entry is a
# (model class, tokenizer class) pair used to load the base model and its
# tokenizer.
MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}
def main():
    """Merge a LoRA (PEFT) adapter into its base model and save the result.

    Steps:
      1. Parse the command-line arguments.
      2. Read the adapter config from ``--peft_model_path``.
      3. Load the base model: ``AutoModelForSequenceClassification`` when the
         adapter's task type is ``SEQ_CLS``, otherwise the model class
         registered in ``MODEL_CLASSES`` for ``--model_type``.
      4. Load the tokenizer from ``--tokenizer_path`` or, when that is not
         given, from ``--peft_model_path``.
      5. With ``--resize_emb``, resize the token embeddings to
         ``len(tokenizer)`` if the sizes differ.
      6. Attach the adapter, merge it with ``merge_and_unload`` and write the
         tokenizer and merged model to ``--output_dir``.

    Raises:
        ValueError: if the adapter's task type is ``SEQ_CLS`` and
            ``--model_type`` is ``chatglm``.
    """
    parser = argparse.ArgumentParser()
    # Restricting the choices makes argparse reject an unknown model type with
    # a usage message instead of a bare KeyError on the registry lookup below.
    parser.add_argument('--model_type', type=str, required=True, choices=list(MODEL_CLASSES),
                        help="Model type, one of: " + ", ".join(MODEL_CLASSES))
    parser.add_argument('--base_model_name_or_path', required=True, type=str,
                        help="Base model name or path")
    parser.add_argument('--tokenizer_path', default=None, type=str,
                        help="Please specify tokenization path.")
    parser.add_argument('--peft_model_path', required=True, type=str,
                        help="Please specify LoRA model to be merged.")
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    parser.add_argument('--output_dir', default='./merged', type=str)
    args = parser.parse_args()
    print(args)

    base_model_path = args.base_model_name_or_path
    peft_model_path = args.peft_model_path
    output_dir = args.output_dir
    print(f"Base model: {base_model_path}")
    print(f"LoRA model: {peft_model_path}")

    peft_config = PeftConfig.from_pretrained(peft_model_path)
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Pick the loader class first so the loading arguments live in one place.
    if peft_config.task_type == "SEQ_CLS":
        print("Loading LoRA for sequence classification model")
        if args.model_type == "chatglm":
            raise ValueError("chatglm does not support sequence classification")
        loader_class = AutoModelForSequenceClassification
    else:
        print("Loading LoRA for causal language model")
        loader_class = model_class
    base_model = loader_class.from_pretrained(
        base_model_path,
        load_in_8bit=False,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
    )

    # Fall back to the adapter directory when no tokenizer path was given.
    tokenizer_source = args.tokenizer_path if args.tokenizer_path else peft_model_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_source, trust_remote_code=True)

    if args.resize_emb:
        base_model_token_size = base_model.get_input_embeddings().weight.size(0)
        if base_model_token_size != len(tokenizer):
            base_model.resize_token_embeddings(len(tokenizer))
            print(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")

    lora_model = PeftModel.from_pretrained(
        base_model,
        peft_model_path,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    lora_model.eval()

    print("Merging with merge_and_unload...")
    base_model = lora_model.merge_and_unload()

    print("Saving to Hugging Face format...")
    tokenizer.save_pretrained(output_dir)
    base_model.save_pretrained(output_dir)
    print(f"Done! model saved to {output_dir}")
# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()