# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
import sys
import json
from math import ceil
import torch
import numpy as np
from torch import tensor
from torch.nn.functional import log_softmax
from torch.distributions.categorical import Categorical
from transformers import T5Tokenizer, T5ForConditionalGeneration
# load UnifiedQA onto device
model_name = "allenai/unifiedqa-v2-t5-large-1363200"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
def get_inputs(contexts_json, ranked_contexts_json):
    with open(contexts_json, 'rt') as fp:
        contexts = json.load(fp)
    with open(ranked_contexts_json, 'rt') as fp:
        ranked_contexts = json.load(fp)
    question_id = list(ranked_contexts.keys())[0]
    # assert len(questions) == 1, f'JSON should only have 1 question but found {len(questions)}: {questions}'
    question = ranked_contexts[question_id]['text']
    context_ids_sorted = ranked_contexts[question_id]['ranks']
    context_scores = ranked_contexts[question_id]['scores']
    contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    # returns the question (str), its contexts (list of str), and the context scores
    return question, contexts, context_scores
def get_tokens(text, tokenizer, max_tokens):
    return tokenizer.encode_plus(text, return_tensors='pt', max_length=max_tokens, padding='max_length', truncation=True)['input_ids']
def prepare_inputs(tokenizer, max_tokens, context, question):
    input_str = f'{question} \\n {context}'
    inputs = get_tokens(input_str, tokenizer, max_tokens)
    return inputs
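# Illustrative input string built by prepare_inputs (added example, not from the original
# source; it follows the "question \n context" layout constructed above):
#   'what color is the sky? \n the sky appears blue because of Rayleigh scattering'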
def get_outputs(model, tokenizer, input_tokens, max_tokens):
    output_dict = model.generate(input_tokens, output_scores=True, return_dict_in_generate=True, **{'max_length': max_tokens})
    pred_tokens = output_dict['sequences'].squeeze().tolist()
    # initialize metrics
    logit_entropy = []
    sentence_probs = []
    # accumulate metrics over logit_sequence
    logit_sequence = output_dict['scores'][:-1]  # discard end token
    for logit in logit_sequence:
        log_probs = log_softmax(logit, dim=-1)
        # update metrics
        logit_entropy.append(Categorical(log_probs.exp()).entropy())
        sentence_probs.append(log_probs.max())
    # finish metrics calculation
    logit_entropy = tensor(logit_entropy)
    sentence_probs = tensor(sentence_probs)
    entropy = logit_entropy.mean()
    sentence_std = 0 if len(logit_sequence) == 1 else sentence_probs.std(unbiased=True).exp()
    # use entropy * sentence_std as uncertainty
    uncertainty = (entropy * sentence_std).item()
    # convert answer tokens to str
    pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()
    return pred_str, uncertainty
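# Interpretation note (added, not from the original source): the uncertainty score is the
# mean per-token entropy scaled by exp(std) of the per-token max log-probs, so confident,
# consistently-peaked generations score low while flat or erratic distributions score high.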
# k_percent: fraction of contexts to use; the resulting count k is clamped to the range [min_k, max_k]
# min_k: minimum number of contexts to use, if possible. Setting this too small reduces recall
# max_k: maximum number of contexts to use. Setting this too big reduces precision
# recommended uncertainty thresholds are 2, 3, 4, and 5; the lower the threshold, the more aggressive the filtering
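# Worked example of the top-k cutoff below (added illustration): with 200 ranked contexts
# and the defaults k_percent=0.1, min_k=10, max_k=25,
#   k = min(max(ceil(0.1 * 200), 10), 25) = min(max(20, 10), 25) = 20
# so only the 20 highest-ranked contexts are passed to the reader model.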
def run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3):
    k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
    contexts = contexts[:k]
    context_scores = context_scores[:k]
    # iterate through top-k contexts
    answers = []
    uncertainty = []
    for context in contexts:
        input_tokens = prepare_inputs(tokenizer, 512, context, question).to(device)
        pred_str, uncertainty_1 = get_outputs(model, tokenizer, input_tokens, 512)
        answers.append(pred_str)
        uncertainty.append(uncertainty_1)
    # contexts = np.array(contexts)
    # answers = np.array(answers)
    # uncertainty = np.array(uncertainty)
    # sort by uncertainty, ascending order
    # order = np.argsort(uncertainty)
    # contexts = contexts[order]
    # answers = answers[order]
    # uncertainty = uncertainty[order]
    # init lists for threshed answers
    # weak_contexts = []
    # weak_answers = []
    # weak_uncertainty = []
    # filter by uncertainty
    # if len(answers) > min_k:
    #     weak = np.argwhere(uncertainty > uncertainty_thresh)  # exceeds threshold
    #     weak_contexts = contexts[weak].tolist()
    #     weak_answers = answers[weak].tolist()
    #     weak_uncertainty = uncertainty[weak].tolist()
    #     strong = np.argwhere(uncertainty <= uncertainty_thresh)  # within threshold
    #     contexts = contexts[strong]
    #     answers = answers[strong]
    #     uncertainty = uncertainty[strong]
    # contexts = contexts.tolist()
    # answers = answers.tolist()
    # uncertainty = uncertainty.tolist()
    # return {'contexts': contexts, 'answers': answers, 'uncertainty': uncertainty}, \
    #        {'contexts': weak_contexts, 'answers': weak_answers, 'uncertainty': weak_uncertainty}
    return {'contexts': contexts, 'answers': answers, 'context_scores': context_scores, 'uncertainty': uncertainty}
def get_qa_results(contexts_json, ranked_contexts_json, topk):
    # extract question and contexts from json
    question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)
    # infer answers
    with torch.inference_mode(True):
        # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
        qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
    return qa_results
def get_qa_results_in_memory(contexts, ranked_contexts, topk):
    question_id = list(ranked_contexts.keys())[0]
    # assert len(questions) == 1, f'JSON should only have 1 question but found {len(questions)}: {questions}'
    question = ranked_contexts[question_id]['text']
    context_ids_sorted = ranked_contexts[question_id]['ranks']
    context_scores = ranked_contexts[question_id]['scores']
    contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    # infer answers
    with torch.inference_mode(True):
        # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
        qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
    return qa_results
def load_custom_model(finetuned_model_path):
    global tokenizer
    global model
    # load UnifiedQA onto device
    tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
    model.to(device)
def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
    # infer answers
    with torch.inference_mode(True):
        # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
        qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
    return qa_results
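# Minimal usage sketch (added illustration; 'contexts.json' and 'ranked_contexts.json' are
# hypothetical placeholder paths, not files provided by this module). contexts_json maps
# context ids to {'text': ...}; ranked_contexts_json maps a single question id to
# {'text': ..., 'ranks': [...], 'scores': [...]}, as read by get_inputs above.
if __name__ == '__main__':
    results = get_qa_results('contexts.json', 'ranked_contexts.json', topk=5)
    for answer, score in zip(results['answers'], results['uncertainty']):
        print(f'{answer}\t(uncertainty={score:.3f})')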