import os
import os.path

import torch
from datasets import load_dataset
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

from src.retrievers.base_retriever import Retriever
from src.utils.log import get_logger

# Hacky fix for a FAISS error on macOS
# See https://stackoverflow.com/a/63374568/4545692
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

logger = get_logger()


class FaissRetriever(Retriever):
    """A class used to retrieve relevant documents based on some query.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
        """Initializes the retriever.

        Args:
            dataset_name (str, optional): The dataset to build the retrieval
                index on. Assumes the text is stored in a column named
                'text'. Defaults to "GroNLP/ik-nlp-22_slp".
        """
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
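        # DPR uses two independently trained BERT-based encoders that map
        # passages and questions into the same vector space, so relevance
        # reduces to an inner-product nearest-neighbour search.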

        # Dataset building
        self.dataset_name = dataset_name
        self.dataset = self._init_dataset(dataset_name)

    def _init_dataset(
        self,
        dataset_name: str,
        embedding_path: str = "./src/models/paragraphs_embedding.faiss",
        force_new_embedding: bool = False,
    ):
        """Loads the dataset and adds a FAISS index over its embeddings.

        Args:
            dataset_name (str): A HuggingFace dataset name.
            embedding_path (str, optional): Where to save the FAISS index on
                disk for faster loading after the first run. Defaults to
                "./src/models/paragraphs_embedding.faiss".
            force_new_embedding (bool, optional): Whether to recompute the
                embeddings and index even if they already exist on disk.
                Defaults to False.

        Returns:
            Dataset: A dataset with a FAISS index over an 'embeddings'
                column.
        """
        # Load dataset
        ds = load_dataset(dataset_name, name="paragraphs")[
            "train"]  # type: ignore

        if not force_new_embedding and os.path.exists(embedding_path):
            # If we already have FAISS embeddings, load them from disk
            ds.load_faiss_index("embeddings", embedding_path)  # type: ignore
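            # Only the serialized index is restored here, not the
            # 'embeddings' column itself; get_nearest_examples queries the
            # index directly, so the column is not needed.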
            return ds
        else:
            # If there are no FAISS embeddings, generate them
            def embed(row):
                # Inline helper function to perform embedding
                p = row["text"]
                tok = self.ctx_tokenizer(
                    p, return_tensors="pt", truncation=True)
                enc = self.ctx_encoder(**tok)[0][0].numpy()
                return {"embeddings": enc}

            # Add FAISS embeddings
            ds_with_embeddings = ds.map(embed)  # type: ignore
            ds_with_embeddings.add_faiss_index(column="embeddings")

            # Save the FAISS index to disk for faster loading on later runs
            os.makedirs("./src/models/", exist_ok=True)
            ds_with_embeddings.save_faiss_index("embeddings", embedding_path)

            return ds_with_embeddings

    def retrieve(self, query: str, k: int = 5):
        """Retrieves the top k matches for a search query.

        Args:
            query (str): A search query.
            k (int, optional): The number of documents to retrieve.
                Defaults to 5.

        Returns:
            tuple: A tuple (scores, results) with the retrieval scores and
                the matching examples.
        """
        def embed(q):
            # Inline helper function to perform embedding
            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
            return self.q_encoder(**tok)[0][0].numpy()

        question_embedding = embed(query)
        scores, results = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )
        return scores, results
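

# A minimal usage sketch; assumes the dataset and the pretrained DPR
# checkpoints can be downloaded. The first run builds and saves the FAISS
# index; later runs load it from disk. The query below is only an example.
if __name__ == "__main__":
    retriever = FaissRetriever()
    scores, results = retriever.retrieve("What is a regular expression?", k=3)
    for score, text in zip(scores, results["text"]):
        print(f"{score:.2f}\t{text[:80]}")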