import os
import os.path
import torch
from datasets import load_dataset
from transformers import (
DPRContextEncoder,
DPRContextEncoderTokenizer,
DPRQuestionEncoder,
DPRQuestionEncoderTokenizer,
)
from src.retrievers.base_retriever import Retriever
from src.utils.log import get_logger
# Hacky fix for FAISS error on macOS
# See https://stackoverflow.com/a/63374568/4545692
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
logger = get_logger()
class FaissRetriever(Retriever):
"""A class used to retrieve relevant documents based on some query.
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
"""
def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
"""Initialize the retriever
Args:
dataset (str, optional): The dataset to train on. Assumes the
information is stored in a column named 'text'. Defaults to
"GroNLP/ik-nlp-22_slp".
"""
torch.set_grad_enabled(False)
# Context encoding and tokenization
self.ctx_encoder = DPRContextEncoder.from_pretrained(
"facebook/dpr-ctx_encoder-single-nq-base"
)
self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
"facebook/dpr-ctx_encoder-single-nq-base"
)
# Question encoding and tokenization
self.q_encoder = DPRQuestionEncoder.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
"facebook/dpr-question_encoder-single-nq-base"
)
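        # DPR trains the question and context encoders jointly, so question
        # and passage embeddings live in the same vector space and can be
        # compared directly.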
# Dataset building
self.dataset_name = dataset_name
self.dataset = self._init_dataset(dataset_name)
    def _init_dataset(
            self,
            dataset_name: str,
            embedding_path: str = "./src/models/paragraphs_embedding.faiss",
            force_new_embedding: bool = False):
        """Loads the dataset and adds FAISS embeddings.

        Args:
            dataset_name (str): A Hugging Face dataset name.
            embedding_path (str, optional): The path used to save the FAISS
                index to disk for faster loading after the first run.
                Defaults to "./src/models/paragraphs_embedding.faiss".
            force_new_embedding (bool, optional): Whether to regenerate the
                embeddings even if a saved index already exists. Defaults to
                False.

        Returns:
            Dataset: A dataset with a new column 'embeddings' containing
                FAISS embeddings.
        """
        # Load the dataset
        ds = load_dataset(
            dataset_name, name="paragraphs")["train"]  # type: ignore
if not force_new_embedding and os.path.exists(embedding_path):
# If we already have FAISS embeddings, load them from disk
            ds.load_faiss_index("embeddings", embedding_path)  # type: ignore
return ds
else:
# If there are no FAISS embeddings, generate them
def embed(row):
# Inline helper function to perform embedding
p = row["text"]
tok = self.ctx_tokenizer(
p, return_tensors="pt", truncation=True)
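                # DPRContextEncoder returns the pooler output first; take the
                # single batch item and convert it to NumPy for FAISS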
enc = self.ctx_encoder(**tok)[0][0].numpy()
return {"embeddings": enc}
# Add FAISS embeddings
ds_with_embeddings = ds.map(embed) # type: ignore
ds_with_embeddings.add_faiss_index(column="embeddings")
            # Save the FAISS index to disk for faster loading on later runs
            os.makedirs("./src/models/", exist_ok=True)
            ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
return ds_with_embeddings
def retrieve(self, query: str, k: int = 5):
"""Retrieve the top k matches for a search query.
Args:
query (str): A search query
k (int, optional): The number of documents to retrieve. Defaults to
5.
Returns:
tuple: A tuple of lists of scores and results.
"""
def embed(q):
# Inline helper function to perform embedding
tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
return self.q_encoder(**tok)[0][0].numpy()
question_embedding = embed(query)
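        # Query the FAISS index for the k passages whose embeddings are
        # closest to the question embedding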
scores, results = self.dataset.get_nearest_examples(
"embeddings", question_embedding, k=k
)
return scores, results
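

# Minimal usage sketch (an illustration, not part of the original module).
# It assumes this file lives under src/retrievers/ so the imports above
# resolve, that the GroNLP/ik-nlp-22_slp dataset is accessible, and that the
# query below is an arbitrary example.
if __name__ == "__main__":
    retriever = FaissRetriever()
    scores, results = retriever.retrieve("What is a regular expression?", k=3)
    # `results` maps column names to lists with one entry per retrieved row
    for score, text in zip(scores, results["text"]):
        print(f"{score:.2f}\t{text[:80]}")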