|
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel |
|
from transformers import GPT2TokenizerFast, GPT2Tokenizer |
|
from easyeditor import apply_grace_to_model, GraceHyperParams,nethook |
|
import torch |
|
|
|
|
|
|
|
def edit(prompt, target_new):
    """Apply one GRACE edit to a freshly loaded GPT-2 XL and publish it globally.

    The edited model returned by ``apply_grace_to_model`` is stored in the
    module-level name ``edit_model`` so that ``generate()`` can use it later.

    Args:
        prompt: Prompt text of the edit request.
        target_new: Desired new completion for ``prompt``.

    Returns:
        The status string ``"finish"``.
    """
    global edit_model

    edit_request = dict(prompt=prompt, target_new=target_new)
    grace_config = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl.yaml")

    checkpoint = "./models/gpt2-xl"
    base_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    # Reuse EOS as the padding id (the GPT-2 tokenizer presumably lacks a
    # dedicated pad token, hence this assignment).
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Only the first element of the returned pair (the edited model) is kept.
    edited, _discarded = apply_grace_to_model(
        base_model,
        tokenizer,
        edit_request,
        grace_config,
        keep_original_weight=True,
    )
    edit_model = edited

    return "finish"
|
|
|
def generate(input_text, max_length=30):
    """Continue ``input_text`` with both the GRACE-edited and the original model.

    The edited model stored in the module-level ``edit_model`` (set by
    ``edit()``) is used first. It is then deleted to free GPU memory before a
    second, unedited copy of the checkpoint is loaded for comparison. As a
    consequence, ``edit()`` must be called again before each further call to
    ``generate()``.

    Args:
        input_text: Prompt to continue.
        max_length: Forwarded to ``model.generate`` for both models. In Hugging
            Face generation this bounds the total sequence length, prompt tokens
            included. Defaults to 30, the previously hard-coded value.

    Returns:
        ``(ori_reply, edit_reply)``: the decoded outputs of the original and
        the edited model, in that order.

    Raises:
        RuntimeError: if no edited model is available, i.e. ``edit()`` has not
            been called since start-up or since the previous ``generate()``.
    """
    global edit_model

    # Fail fast with a clear message (instead of a bare NameError after loading
    # tokenizer/hparams): `edit_model` only exists after edit() and is deleted
    # below on every call.
    if "edit_model" not in globals():
        raise RuntimeError(
            "No edited model available: call edit() before generate()."
        )

    tok = GPT2Tokenizer.from_pretrained("./models/gpt2-xl")
    hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl.yaml")
    tok.pad_token_id = tok.eos_token_id
    device = f'cuda:{hparams.device}'

    input_ids = tok.encode(input_text, return_tensors='pt').to(device)

    edit_output = edit_model.generate(
        input_ids, max_length=max_length, pad_token_id=tok.eos_token_id
    )
    edit_reply = tok.decode(edit_output[0], skip_special_tokens=True)

    # Release the edited model (and cached GPU blocks) before loading a second
    # full copy of the weights onto the same device.
    del edit_model
    torch.cuda.empty_cache()

    ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2-xl").to(device)
    ori_output = ori_model.generate(
        input_ids, max_length=max_length, pad_token_id=tok.eos_token_id
    )
    ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)

    # Free the comparison model as well. Otherwise its blocks stay reserved by
    # PyTorch's caching allocator after the function returns.
    del ori_model
    torch.cuda.empty_cache()

    return ori_reply, edit_reply
|
|