File size: 3,104 Bytes
8bbe3aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, \
DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from datasets import load_dataset
import torch
class Retriever():
    """Retrieve relevant documents for a query using DPR encoders and FAISS.

    Passages of a HuggingFace dataset are embedded with a DPR context
    encoder and indexed with FAISS; queries are embedded with the matching
    DPR question encoder and looked up in that index.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, dataset: str = "GroNLP/ik-nlp-22_slp") -> None:
        """Initialize the retriever.

        Loads the pretrained encoders and tokenizers, then loads the dataset
        and builds its FAISS index.

        Args:
            dataset (str, optional): The HuggingFace dataset name to load and
                index. Assumes the information is stored in a column named
                'text'. Defaults to "GroNLP/ik-nlp-22_slp".
        """
        # NOTE: this switches autograd off for the whole process, not just for
        # this object. It is kept for backward compatibility; the embedding
        # code below no longer depends on it (see `_embed`).
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base")
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base")

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")

        # Dataset building
        self.dataset = self.__init_dataset(dataset)

    @staticmethod
    def _embed(tokenizer, encoder, text):
        """Encode a single text into one embedding vector.

        Args:
            tokenizer: Callable turning `text` into keyword arguments for
                `encoder` (called with return_tensors='pt', truncation=True).
            encoder: Callable producing the model output for those arguments.
            text: The text to embed.

        Returns:
            The result of `.numpy()` on element [0][0] of the encoder output
            (presumably the pooled embedding of the single input text).
        """
        tokens = tokenizer(text, return_tensors='pt', truncation=True)
        # Disable autograd locally: `.numpy()` raises on tensors that require
        # grad, so relying on the global switch set in `__init__` breaks as
        # soon as other code re-enables gradients.
        with torch.no_grad():
            return encoder(**tokens)[0][0].numpy()

    def __init_dataset(self, dataset: str):
        """Loads the dataset and adds FAISS embeddings.

        Args:
            dataset (str): A HuggingFace dataset name.

        Returns:
            Dataset: The 'train' split of the 'paragraphs' configuration with
            a new column 'embeddings' and a FAISS index on that column.
        """
        # TODO: save ds w/ embeddings to disk and retrieve it if it already exists

        # Load dataset
        ds = load_dataset(dataset, name='paragraphs')['train']

        def embed_row(row):
            # Embed the 'text' column of one row with the context encoder.
            return {'embeddings': self._embed(
                self.ctx_tokenizer, self.ctx_encoder, row['text'])}

        # Add the embeddings column, then index it with FAISS
        ds_with_embeddings = ds.map(embed_row)
        # TODO: building the index was reported to throw a weird error; the
        # cause has not been identified yet.
        ds_with_embeddings.add_faiss_index(column='embeddings')
        return ds_with_embeddings

    def retrieve(self, query: str, k: int = 5) -> tuple:
        """Retrieve the top k matches for a search query.

        Args:
            query (str): A search query
            k (int, optional): The number of documents to retrieve. Defaults to
                5.

        Returns:
            tuple: A tuple of lists of scores and results.
        """
        question_embedding = self._embed(
            self.q_tokenizer, self.q_encoder, query)
        scores, results = self.dataset.get_nearest_examples(
            'embeddings', question_embedding, k=k)
        return scores, results
|