from collections import OrderedDict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
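
# Convert a checkpoint whose attention layers pack the query, key, and value
# projections into a single 'self_attn.W_pack' matrix (the layout used by
# Baichuan-style models) into separate q_proj/k_proj/v_proj weights on disk.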

# Placeholder paths: the source checkpoint and the output directory.
model_path = 'model_path'
out_path = 'out_path'

# Load the packed model and its tokenizer; trust_remote_code is needed when
# the checkpoint ships its own modeling code.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

# Split each packed W_pack matrix into separate q/k/v projection weights;
# W_pack stacks the three projections along dim 0, hidden_size rows apiece.
new_dict = OrderedDict()
hidden_size = model.config.hidden_size
for name, param in model.state_dict().items():
    if 'self_attn.W_pack' not in name:
        new_dict[name] = param
        continue
    name_base = name[:name.find('W_pack.weight')]
    q, k, v = [param[hidden_size * i : hidden_size * (i + 1), :] for i in range(3)]
    new_dict[name_base + 'q_proj.weight'] = q
    new_dict[name_base + 'k_proj.weight'] = k
    new_dict[name_base + 'v_proj.weight'] = v
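
# Optional sanity check: the three slices must tile the packed matrix exactly.
# (A minimal sketch; the layer-0 key names below assume the conventional
# 'model.layers.N.self_attn.' prefix and are illustrative, not guaranteed.)
packed = model.state_dict()['model.layers.0.self_attn.W_pack.weight']
restacked = torch.cat([
    new_dict['model.layers.0.self_attn.q_proj.weight'],
    new_dict['model.layers.0.self_attn.k_proj.weight'],
    new_dict['model.layers.0.self_attn.v_proj.weight'],
], dim=0)
assert torch.equal(packed, restacked), 'q/k/v slices do not tile W_pack'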

# Write the converted weights and the tokenizer files to the output directory.
model.save_pretrained(out_path, state_dict=new_dict)
tokenizer.save_pretrained(out_path)
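
# Note: save_pretrained copies the model's original config unchanged, so the
# files at out_path still declare the packed architecture; loading them as a
# LLaMA-style model typically also requires swapping in a matching config.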