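# Interactive chat loop for a persona-conditioned GPT / GPT-2 model, in the
# style of HuggingFace's transfer-learning-conv-ai example: a personality is
# sampled from PERSONACHAT and replies are decoded with temperature, top-k and
# nucleus (top-p) filtering.
#
# Illustrative invocation (script filename and checkpoint path are hypothetical;
# a checkpoint is only required for --model gpt2, otherwise a finetuned
# openai-gpt model is downloaded automatically):
#   python interact.py --model gpt2 --model_checkpoint ./runs/my_checkpoint \
#       --temperature 0.7 --top_p 0.9
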
import logging
import json
import os
import random
from argparse import ArgumentParser
from itertools import chain
from pprint import pformat
import tempfile
import tarfile
import warnings

import torch
import torch.nn.functional as F

from transformers import cached_path
from transformers import (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          GPT2LMHeadModel, GPT2Tokenizer)

HF_FINETUNED_MODEL = "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/gpt_personachat_cache.tar.gz"
# PERSONACHAT dataset used when --dataset_path is left empty.
PERSONACHAT_URL = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"

SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
ATTR_TO_SPECIAL_TOKEN = {'bos_token': '<bos>', 'eos_token': '<eos>', 'pad_token': '<pad>',
                         'additional_special_tokens': ['<speaker1>', '<speaker2>']}


def download_pretrained_model():
    """ Download and extract the finetuned model from S3. """
    resolved_archive_file = cached_path(HF_FINETUNED_MODEL)
    tempdir = tempfile.mkdtemp()
    with tarfile.open(resolved_archive_file, 'r:gz') as archive:
        def is_within_directory(directory, target):
            # A member is only safe to extract if its resolved path stays
            # inside the target directory.
            abs_directory = os.path.abspath(directory)
            abs_target = os.path.abspath(target)
            prefix = os.path.commonprefix([abs_directory, abs_target])
            return prefix == abs_directory

        def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
            # Guard against path traversal ("../") entries before extracting.
            for member in tar.getmembers():
                member_path = os.path.join(path, member.name)
                if not is_within_directory(path, member_path):
                    raise Exception("Attempted Path Traversal in Tar File")
            tar.extractall(path, members, numeric_owner=numeric_owner)

        safe_extract(archive, tempdir)
    return tempdir


def get_dataset(tokenizer, dataset_path, dataset_cache):
    """ Get the tokenized PERSONACHAT dataset from S3 or from a local cache. """
    dataset_path = dataset_path or PERSONACHAT_URL
    # Cache the tokenized dataset per tokenizer class, since different
    # tokenizers produce different ids.
    dataset_cache = dataset_cache + '_' + type(tokenizer).__name__
    if dataset_cache and os.path.isfile(dataset_cache):
        dataset = torch.load(dataset_cache)
    else:
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            dataset = json.loads(f.read())

        def tokenize(obj):
            # Recursively tokenize every string in the nested dataset structure.
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return dict((n, tokenize(o)) for n, o in obj.items())
            return list(tokenize(o) for o in obj)

        dataset = tokenize(dataset)
        torch.save(dataset, dataset_cache)
    return dataset
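
# For reference, the PERSONACHAT JSON loaded above is expected to look roughly
# like the following (illustrative, heavily abbreviated): splits mapping to
# lists of dialogs, each carrying a "personality" (a list of persona sentences)
# plus the dialog utterances. run() below only relies on the splits and the
# "personality" field.
#   {"train": [{"personality": ["i like dogs .", ...], "utterances": [...]},
#              ...],
#    "valid": [...]}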


def add_special_tokens_(model, tokenizer):
    """ Add special tokens to the tokenizer and the model if they have not already been added. """
    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)  # returns 0 if they are already there
    if num_added_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)


def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=False, with_eos=True):
    """ Build a sequence of input from 3 segments: persona, history and last reply. """
    bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])
    sequence = [[bos] + list(chain(*persona))] + history + [reply + ([eos] if with_eos else [])]
    # Prepend the alternating speaker tokens to every segment after the persona.
    sequence = [sequence[0]] + [[speaker2 if (len(sequence) - i) % 2 else speaker1] + s
                                for i, s in enumerate(sequence[1:])]
    instance = {}
    instance["input_ids"] = list(chain(*sequence))
    instance["token_type_ids"] = [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in s]
    instance["mc_token_ids"] = len(instance["input_ids"]) - 1
    # Positions labelled -100 are ignored by the language modeling loss; only
    # the reply tokens contribute when lm_labels is True.
    instance["lm_labels"] = [-100] * len(instance["input_ids"])
    if lm_labels:
        instance["lm_labels"] = ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:]
    return instance
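
# Illustrative layout produced above for a persona, a single user turn and the
# (possibly partial) reply, written with token names instead of ids:
#   input_ids:      <bos> persona ... <speaker2> user turn ... <speaker1> reply ... <eos>
#   token_type_ids: the speaker token of each segment, repeated over its length
#   lm_labels:      -100 everywhere except over the reply tokens (lm_labels=True)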


def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'),
                  filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
            top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
                whose total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest probability tokens whose cumulative probability mass exceeds
                the threshold top_p.
            threshold: a minimal threshold to keep logits
    """
    assert logits.dim() == 1  # this script decodes one sequence at a time, so logits is 1-D
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of the sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with a cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the mask to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map back to the unsorted indices and set the filtered logits to -inf
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits
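
# Worked example of the nucleus (top-p) step above: for token probabilities
# [0.5, 0.3, 0.15, 0.05] and top_p=0.9, the cumulative mass is
# [0.5, 0.8, 0.95, 1.0]; after the shift, only the 0.05 token is filtered out,
# so the kept set is the smallest one whose mass (0.95) reaches 0.9.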


def sample_sequence(personality, history, tokenizer, model, args, current_output=None):
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []

    for i in range(args.max_length):
        instance = build_input_from_segments(personality, history, current_output,
                                             tokenizer, with_eos=False)

        input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"],
                                      device=args.device).unsqueeze(0)

        logits = model(input_ids, token_type_ids=token_type_ids)
        if isinstance(logits, tuple):  # older transformers versions return a tuple (logits first)
            logits = logits[0]
        # Keep only the logits of the last position, then apply temperature and filtering
        logits = logits[0, -1, :] / args.temperature
        logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p)
        probs = F.softmax(logits, dim=-1)

        prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)
        if i < args.min_length and prev.item() in special_tokens_ids:
            # Resample until we get a regular token, or give up if the model
            # puts all of its mass on a special token.
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn(
                        "Warning: model generating special token with probability 1.")
                    break
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output


def run():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model", type=str, default="openai-gpt",
                        help="Model type (openai-gpt or gpt2)",
                        choices=['openai-gpt', 'gpt2'])
    parser.add_argument("--model_checkpoint", type=str, default="",
                        help="Path, url or short name of the model")
    parser.add_argument("--max_history", type=int, default=2,
                        help="Number of previous utterances to keep in history")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")

    parser.add_argument("--no_sample", action='store_true',
                        help="Set to use greedy decoding instead of sampling")
    parser.add_argument("--max_length", type=int, default=20,
                        help="Maximum length of the output utterances")
    parser.add_argument("--min_length", type=int, default=1,
                        help="Minimum length of the output utterances")
    parser.add_argument("--seed", type=int, default=0, help="Seed")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling softmax temperature")
    parser.add_argument("--top_k", type=int, default=0,
                        help="Filter top-k tokens before sampling (<=0: no filtering)")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__file__)
    logger.info(pformat(args))

    if args.model_checkpoint == "":
        if args.model == 'gpt2':
            raise ValueError(
                "Interacting with GPT2 requires passing a finetuned model_checkpoint")
        else:
            args.model_checkpoint = download_pretrained_model()

    if args.seed != 0:
        random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    logger.info("Get pretrained model and tokenizer")
    tokenizer_class, model_class = (
        (GPT2Tokenizer, GPT2LMHeadModel) if args.model == 'gpt2'
        else (OpenAIGPTTokenizer, OpenAIGPTLMHeadModel))
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    add_special_tokens_(model, tokenizer)

    logger.info("Sample a personality")
    dataset = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)
    personalities = [dialog["personality"] for split in dataset.values() for dialog in split]
    personality = random.choice(personalities)
    logger.info("Selected personality: %s", tokenizer.decode(chain(*personality)))

    history = []
    while True:
        raw_text = input(">>> ")
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input(">>> ")
        history.append(tokenizer.encode(raw_text))
        with torch.no_grad():
            out_ids = sample_sequence(personality, history, tokenizer, model, args)
        history.append(out_ids)
        # Keep only the last max_history exchanges plus the latest user input.
        history = history[-(2 * args.max_history + 1):]
        out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
        print(out_text)


if __name__ == "__main__":
    run()