File size: 1,618 Bytes
6798aed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
# Load the PersonaGPT checkpoint with the GPT-2 classes this file imports.
# The previous Auto* class names were never imported, so the script raised
# NameError before doing any work.
tokenizer = GPT2Tokenizer.from_pretrained("af1tang/personaGPT")
model = GPT2LMHeadModel.from_pretrained("af1tang/personaGPT")
# Move the model to the GPU when CUDA is available; otherwise leave it as loaded.
if torch.cuda.is_available():
    model = model.cuda()
## utility functions ##
def flatten(l):
    """Return one flat list holding the items of every sub-iterable in ``l``.

    Only a single level of nesting is removed; items that are themselves
    containers are kept as-is.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
def to_data(x):
    """Return the values of tensor ``x`` as a NumPy array.

    The tensor is detached from the autograd graph and copied to host memory
    before conversion.

    Args:
        x: a ``torch.Tensor`` on any device.

    Returns:
        A ``numpy.ndarray`` with the same values as ``x``.
    """
    # Decide based on the tensor itself rather than on whether the machine has
    # CUDA: ``cpu()`` is a no-op for host tensors and works for any device,
    # and ``detach()`` replaces the legacy ``.data`` attribute.
    return x.detach().cpu().numpy()
def to_var(x):
    """Convert ``x`` to a tensor and move it to the GPU when CUDA is available.

    Args:
        x: an existing tensor (returned with its dtype unchanged) or data such
            as a number or (nested) list, which is converted to a tensor of
            the default floating-point dtype.

    Returns:
        A ``torch.Tensor`` holding the values of ``x``; it is moved with
        ``.cuda()`` when ``torch.cuda.is_available()`` is true.
    """
    if not torch.is_tensor(x):
        # torch.tensor converts the *values* of x. The legacy torch.Tensor(x)
        # constructor treats a bare int as a size and returns uninitialized
        # memory. The explicit dtype keeps the float result the old
        # constructor produced for lists.
        x = torch.tensor(x, dtype=torch.get_default_dtype())
    if torch.cuda.is_available():
        x = x.cuda()
    return x
def display_dialog_history(dialog_hx):
    """Print every turn in ``dialog_hx`` with a speaker label.

    Each entry is decoded with the module-level ``tokenizer``. Entries at
    even positions are printed as user turns, entries at odd positions as
    bot turns.
    """
    for turn, token_ids in enumerate(dialog_hx):
        text = tokenizer.decode(token_ids)
        if turn % 2:
            print("Bot: " + text)
            # NOTE(review): the blank separator line is emitted after each bot
            # turn; the flattened source does not show its original nesting,
            # so confirm it belongs to this branch.
            print()
            continue
        print(">> User: " + text)
def generate_next(bot_input_ids, do_sample=True, top_k=10, top_p=.92,
                  max_length=1000, pad_token=tokenizer.eos_token_id):
    """Generate a continuation of ``bot_input_ids`` and return only the new part.

    Args:
        bot_input_ids: tensor of prompt token ids passed to ``model.generate``;
            its last dimension is taken as the prompt length.
        do_sample: forwarded to ``model.generate``.
        top_k: forwarded to ``model.generate``.
        top_p: forwarded to ``model.generate``.
        max_length: forwarded to ``model.generate``.
        pad_token: forwarded to ``model.generate`` as ``pad_token_id``;
            defaults to the tokenizer's EOS token id.

    Returns:
        The first generated sequence with the prompt prefix removed, converted
        by ``to_data``.
    """
    # Forward the caller's do_sample / pad_token instead of hard-coding them;
    # the defaults match the previously hard-coded values.
    full_msg = model.generate(bot_input_ids, do_sample=do_sample,
                              top_k=top_k, top_p=top_p,
                              max_length=max_length, pad_token_id=pad_token)
    # Keep only the tokens that follow the prompt in the first sequence.
    msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
    return msg
# Ask the user for three persona facts, ending each one with the EOS token.
personas = [
    input(">> Fact %d: " % (fact_no + 1)) + tokenizer.eos_token
    for fact_no in range(3)
]
# Surround the facts with the persona/separator/start markers and pass the
# joined string through tokenizer.encode.
personas = tokenizer.encode(''.join(['<|p2|>', *personas, '<|sep|>', '<|start|>']))
|