Son_space / app.py
siroleo0's picture
Create app.py
6798aed
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
# Load the PersonaGPT tokenizer and language model.
# The GPT-2 classes are used because they are what this file imports from
# `transformers`; the `AutoTokenizer` / `AutoModelForCausalLM` names used
# previously were never imported and raised NameError at startup.
tokenizer = GPT2Tokenizer.from_pretrained("af1tang/personaGPT")
model = GPT2LMHeadModel.from_pretrained("af1tang/personaGPT")
# Run the model on the GPU when one is available.
if torch.cuda.is_available():
    model = model.cuda()
## utility functions ##
def flatten(l):
    """Return one flat list holding the items of every sub-iterable of ``l``.

    Exactly one level of nesting is removed; items that are themselves
    containers are kept as they are.

    Args:
        l: an iterable of iterables (the name is kept for compatibility
            with existing keyword callers).

    Returns:
        A new list with the items in their original order.
    """
    return [item for sublist in l for item in sublist]
def to_data(x):
    """Convert a torch tensor into a NumPy array on the host.

    Args:
        x: a ``torch.Tensor`` on any device, with or without gradient
            tracking.

    Returns:
        A ``numpy.ndarray`` holding the tensor's values.
    """
    # detach() drops the autograd graph (replacing the legacy `.data`), and
    # cpu() is a no-op for tensors already on the host, so no device check
    # is needed and tensors on any accelerator are handled.
    return x.detach().cpu().numpy()
def to_var(x):
    """Return ``x`` as a torch tensor, moved to the GPU when CUDA is available.

    Args:
        x: an existing tensor (returned with its dtype unchanged) or data
            that can be turned into one (number, nested list, array).

    Returns:
        A ``torch.Tensor``. Non-tensor input is converted using the default
        floating-point dtype, which is what ``torch.Tensor(x)`` produced for
        list input.
    """
    if not torch.is_tensor(x):
        # torch.tensor() always copies the *values* of x. The legacy
        # torch.Tensor(x) constructor treats an integer scalar as a size and
        # returns uninitialised memory instead.
        x = torch.tensor(x, dtype=torch.get_default_dtype())
    if torch.cuda.is_available():
        x = x.cuda()
    return x
def display_dialog_history(dialog_hx):
    """Print a conversation, decoding each turn with the module tokenizer.

    Turns at even positions are printed as the user's, turns at odd
    positions as the bot's; a blank line follows every bot turn.

    Args:
        dialog_hx: sequence of turns, each one a value accepted by
            ``tokenizer.decode``.
    """
    for turn_index, token_ids in enumerate(dialog_hx):
        text = tokenizer.decode(token_ids)
        is_user_turn = turn_index % 2 == 0
        if is_user_turn:
            print(">> User: " + text)
            continue
        print("Bot: " + text)
        print()
def generate_next(bot_input_ids, do_sample=True, top_k=10, top_p=.92,
                  max_length=1000, pad_token=tokenizer.eos_token_id):
    """Generate a continuation of ``bot_input_ids`` and return only the new tokens.

    Args:
        bot_input_ids: prompt token ids handed to ``model.generate``; its
            last dimension is taken as the prompt length.
        do_sample: forwarded to ``model.generate`` (previously ignored and
            always ``True``).
        top_k: forwarded to ``model.generate``.
        top_p: forwarded to ``model.generate``.
        max_length: forwarded to ``model.generate``.
        pad_token: forwarded as ``pad_token_id`` (previously ignored in
            favour of ``tokenizer.eos_token_id``, which is still the default).

    Returns:
        NumPy array with the token ids of the first generated sequence,
        with the leading prompt tokens removed.
    """
    full_msg = model.generate(
        bot_input_ids,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        max_length=max_length,
        pad_token_id=pad_token,
    )
    # The generated sequence starts with the prompt; slice it off so that
    # only the newly produced tokens are returned.
    prompt_len = bot_input_ids.shape[-1]
    return to_data(full_msg.detach()[0])[prompt_len:]
# Ask the user for three persona facts and encode them as the model prompt:
# '<|p2|>' + facts (each followed by the EOS token) + '<|sep|>' + '<|start|>'.
personas = ['<|p2|>']
for i in range(3):
    response = input(">> Fact %d: " % (i + 1)) + tokenizer.eos_token
    personas.append(response)
personas.extend(['<|sep|>', '<|start|>'])
# From here on `personas` holds token ids rather than strings.
personas = tokenizer.encode(''.join(personas))