alKoGolik's picture
Upload 169 files
c87c295 verified
raw
history blame
2.63 kB
from typing import Optional
import os, sys
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch
from datetime import datetime
sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from utils.special_tok_llama2 import (
B_CODE,
E_CODE,
B_RESULT,
E_RESULT,
B_INST,
E_INST,
B_SYS,
E_SYS,
DEFAULT_PAD_TOKEN,
DEFAULT_BOS_TOKEN,
DEFAULT_EOS_TOKEN,
DEFAULT_UNK_TOKEN,
IGNORE_INDEX,
)
def create_peft_config(model):
    """Wrap ``model`` with LoRA adapters for causal-LM fine-tuning.

    The model is first passed through PEFT's preparation helper for
    quantized training and then wrapped by ``get_peft_model`` using a LoRA
    configuration (r=8, alpha=32, dropout=0.05) that targets the
    ``q_proj`` and ``v_proj`` modules. The trainable-parameter summary is
    printed as a side effect.

    Args:
        model: The base model to adapt. Presumably a quantized Hugging Face
            causal LM -- confirm against callers.

    Returns:
        A ``(model, peft_config)`` tuple: the PEFT-wrapped model and the
        ``LoraConfig`` that was applied.
    """
    from peft import LoraConfig, TaskType, get_peft_model

    # `prepare_model_for_int8_training` was deprecated in favour of
    # `prepare_model_for_kbit_training` and is missing from newer PEFT
    # releases. Prefer the current name and fall back to the legacy one so
    # the function works across PEFT versions.
    try:
        from peft import prepare_model_for_kbit_training as prepare_quantized_model
    except ImportError:
        from peft import prepare_model_for_int8_training as prepare_quantized_model

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    # Prepare the quantized model for training before attaching adapters.
    model = prepare_quantized_model(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config
def build_model_from_hf_path(
    hf_base_model_path: str = "./ckpt/llama-2-13b-chat",
    load_peft: Optional[bool] = False,
    peft_model_path: Optional[str] = None,
    load_in_4bit: bool = True,
):
    """Load a Llama tokenizer and causal LM from a Hugging Face checkpoint.

    The tokenizer is extended with any missing standard special tokens
    (pad/eos/bos/unk) and with the code-interpreter marker tokens, and the
    model's token embeddings are resized to match the tokenizer length.
    Optionally, PEFT adapter weights are loaded on top of the base model.

    Args:
        hf_base_model_path: Path or identifier passed to ``from_pretrained``
            for both the tokenizer and the model.
        load_peft: When truthy, load PEFT adapters from ``peft_model_path``.
        peft_model_path: Location of the PEFT adapter weights. If this is
            ``None`` while ``load_peft`` is truthy, a warning is logged and
            the base model is returned without adapters.
        load_in_4bit: Forwarded to ``LlamaForCausalLM.from_pretrained``.

    Returns:
        A dict with keys ``"tokenizer"`` and ``"model"``.
    """
    import logging  # local import keeps this block self-contained

    logger = logging.getLogger(__name__)
    start_time = datetime.now()

    # build tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(
        hf_base_model_path,
        padding_side="right",
        use_fast=False,
    )

    # Handle special tokens: register only those the tokenizer lacks.
    # Insertion order (pad, eos, bos, unk) is kept from the original code.
    default_special_tokens = {
        "pad_token": DEFAULT_PAD_TOKEN,  # original author's note: 32000
        "eos_token": DEFAULT_EOS_TOKEN,  # original author's note: 2
        "bos_token": DEFAULT_BOS_TOKEN,  # original author's note: 1
        "unk_token": DEFAULT_UNK_TOKEN,
    }
    special_tokens_dict = {
        attr: token
        for attr, token in default_special_tokens.items()
        if getattr(tokenizer, attr) is None
    }
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.add_tokens(
        [B_CODE, E_CODE, B_RESULT, E_RESULT, B_INST, E_INST, B_SYS, E_SYS],
        special_tokens=True,
    )

    # build model
    model = LlamaForCausalLM.from_pretrained(
        hf_base_model_path,
        load_in_4bit=load_in_4bit,
        device_map="auto",
    )
    # The vocabulary may have grown above; keep the embedding table in sync.
    model.resize_token_embeddings(len(tokenizer))

    if load_peft:
        if peft_model_path is not None:
            from peft import PeftModel

            model = PeftModel.from_pretrained(model, peft_model_path)
        else:
            # Previously this case was skipped silently; surface it so a
            # missing adapter path does not go unnoticed.
            logger.warning(
                "load_peft is set but peft_model_path is None; "
                "returning the base model without PEFT adapters."
            )

    elapsed_time = datetime.now() - start_time
    logger.info(
        "Built tokenizer and model from %s in %s",
        hf_base_model_path,
        elapsed_time,
    )

    return {"tokenizer": tokenizer, "model": model}