from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the PersonaGPT tokenizer and model, and move the model to GPU if one is available.
tokenizer = AutoTokenizer.from_pretrained("af1tang/personaGPT")
model = AutoModelForCausalLM.from_pretrained("af1tang/personaGPT")
if torch.cuda.is_available():
    model = model.cuda()

# Utility helpers for list flattening and tensor handling.
flatten = lambda l: [item for sublist in l for item in sublist]

def to_data(x):
    # Move a tensor back to CPU and return its contents as a numpy array.
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data.numpy()

def to_var(x):
    # Wrap raw input in a tensor and move it to GPU if available.
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if torch.cuda.is_available():
        x = x.cuda()
    return x

def display_dialog_history(dialog_hx):
    # Print the conversation so far, alternating user and bot turns.
    for j, line in enumerate(dialog_hx):
        msg = tokenizer.decode(line)
        if j % 2 == 0:
            print(">> User: " + msg)
        else:
            print("Bot: " + msg)
            print()

def generate_next(bot_input_ids, do_sample=True, top_k=10, top_p=.92,
                  max_length=1000, pad_token=tokenizer.eos_token_id):
    # Sample a reply conditioned on the full input and return only the newly generated tokens.
    full_msg = model.generate(bot_input_ids, do_sample=do_sample,
                              top_k=top_k, top_p=top_p,
                              max_length=max_length, pad_token_id=pad_token)
    msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
    return msg

# Collect persona facts from the user and encode them as the bot's persona prefix,
# wrapped in PersonaGPT's special tokens.
personas = []
for i in range(3):
    response = input(">> Fact %d: " % (i + 1)) + tokenizer.eos_token
    personas.append(response)
personas = tokenizer.encode(''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
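
# A minimal interactive chat loop sketch tying the helpers above together. The
# 8-turn length and the reliance on generate_next's default decoding settings
# are illustrative assumptions, not requirements of the model.
dialog_hx = []
for turn in range(8):
    # Encode the user's message and add it to the dialogue history.
    user_inp = tokenizer.encode(input(">> User: ") + tokenizer.eos_token)
    dialog_hx.append(user_inp)
    # Condition generation on the persona prefix plus the dialogue so far.
    bot_input_ids = to_var([personas + flatten(dialog_hx)]).long()
    msg = generate_next(bot_input_ids)
    dialog_hx.append(msg)
    print("Bot: {}".format(tokenizer.decode(msg, skip_special_tokens=True)))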