# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception

from sentence_transformers.cross_encoder import CrossEncoder as CE
import numpy as np
from typing import List, Dict, Tuple
import json
from collections import defaultdict
import os

class CrossEncoder:
    def __init__(self,
                 model_path: str = None,
                 max_length: int = None,
                 **kwargs):
        # Only pass max_length through when it is explicitly set.
        if max_length is not None:
            self.model = CE(model_path, max_length=max_length, **kwargs)
        else:
            self.model = CE(model_path, **kwargs)

    def predict(self,
                sentences: List[Tuple[str, str]],
                batch_size: int = 32,
                show_progress_bar: bool = False) -> List[float]:
        return self.model.predict(sentences=sentences,
                                  batch_size=batch_size,
                                  show_progress_bar=show_progress_bar)
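
# Illustrative usage of the CrossEncoder wrapper above (a minimal sketch; the
# checkpoint name 'cross-encoder/ms-marco-electra-base' is an assumption, not
# taken from this file):
#   encoder = CrossEncoder('cross-encoder/ms-marco-electra-base', max_length=512)
#   scores = encoder.predict([('What is the capital of France?',
#                              'Paris is the capital of France.')])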

class CERank:
    def __init__(self, model, batch_size: int = 128, **kwargs):
        self.cross_encoder = model
        self.batch_size = batch_size

    def flatten_examples(self, contexts: Dict[str, Dict], question: str):
        # Pair the question with every candidate context so the cross-encoder
        # can score each (question, context) pair.
        text_pairs, pair_ids = [], []
        for context_id, context in contexts.items():
            pair_ids.append(['question_0', context_id])
            text_pairs.append([question, context['text']])
        return text_pairs, pair_ids

    def group_questionrank(self, pair_ids, rank_scores):
        # Collect (context_id, score) tuples per question id.
        unsorted = defaultdict(list)
        for pair, score in zip(pair_ids, rank_scores):
            query_id, paragraph_id = pair[0], pair[1]
            unsorted[query_id].append((paragraph_id, score))
        return unsorted

    def get_rankings(self, pair_ids, rank_scores, text_pairs):
        # Sort each question's contexts by score (highest first) and keep the
        # question text, the sorted context ids, and their scores.
        unsorted_ranks = self.group_questionrank(pair_ids, rank_scores)
        rankings = defaultdict(dict)
        for idx, (query_id, ranks) in enumerate(unsorted_ranks.items()):
            sort_ranks = sorted(ranks, key=lambda item: item[1], reverse=True)
            sorted_ranks, scores = list(zip(*sort_ranks))
            rankings[query_id]['text'] = text_pairs[idx][0]
            rankings[query_id]['scores'] = list(scores)
            rankings[query_id]['ranks'] = list(sorted_ranks)
        return rankings

    def rank(self,
             contexts: Dict[str, Dict],
             question: str):
        text_pairs, pair_ids = self.flatten_examples(contexts, question)
        rank_scores = [float(score) for score in
                       self.cross_encoder.predict(text_pairs, batch_size=self.batch_size)]
        full_results = self.get_rankings(pair_ids, rank_scores, text_pairs)
        return full_results
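
# Illustrative input/output shapes for CERank.rank (a sketch assuming a
# cross-encoder has been instantiated as above; the context ids and texts here
# are made up):
#   contexts = {'ctx_0': {'text': 'Paris is the capital of France.'},
#               'ctx_1': {'text': 'Berlin is the capital of Germany.'}}
#   rankings = CERank(cross_encoder).rank(contexts, 'What is the capital of France?')
#   # -> {'question_0': {'text': '<question>', 'scores': [...], 'ranks': ['ctx_0', 'ctx_1']}}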

def get_ranked_contexts(context_json, question):
    dirname = 'examples'
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Fast tokenizers (use_fast) can't be used while Gradio is running; they
    # conflict with the tokenizer's multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    with open(context_json, 'r') as fin:
        contexts = json.load(fin)
    rankings = ranker.rank(contexts, question)
    with open('ranked_{0}.json'.format(context_json[:-5]), 'w') as fout:
        json.dump(rankings, fout)

def get_ranked_contexts_in_memory(contexts, question):
    dirname = 'examples'
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Fast tokenizers (use_fast) can't be used while Gradio is running; they
    # conflict with the tokenizer's multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    rankings = ranker.rank(contexts, question)
    return rankings
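
# Minimal end-to-end sketch (assumes the 'ms-marco-electra-base' checkpoint is
# available at that path; the example contexts and question are hypothetical):
if __name__ == '__main__':
    example_contexts = {
        'ctx_0': {'text': 'Paris is the capital of France.'},
        'ctx_1': {'text': 'Berlin is the capital of Germany.'},
    }
    results = get_ranked_contexts_in_memory(example_contexts,
                                            'What is the capital of France?')
    print(json.dumps(results, indent=2))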