File size: 4,237 Bytes
83870cc
51dabd6
 
 
 
51a31d4
 
 
 
 
 
51dabd6
51a31d4
 
8bbe3aa
51a31d4
 
ab5dfc2
51a31d4
83870cc
51a31d4
83870cc
51a31d4
ab5dfc2
8bbe3aa
 
 
 
aa426fb
8bbe3aa
 
 
 
 
 
 
 
 
 
 
83870cc
 
8bbe3aa
83870cc
 
8bbe3aa
 
 
83870cc
 
8bbe3aa
83870cc
 
8bbe3aa
 
aa426fb
 
8bbe3aa
ab5dfc2
 
 
 
 
8bbe3aa
 
 
 
ab5dfc2
83870cc
8bbe3aa
 
 
 
 
 
51dabd6
 
8bbe3aa
ab5dfc2
83870cc
51dabd6
83870cc
 
 
 
 
 
 
 
 
 
 
 
51dabd6
83870cc
 
8bbe3aa
83870cc
ab5dfc2
aa426fb
8bbe3aa
83870cc
8bbe3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83870cc
8bbe3aa
 
 
 
83870cc
 
2827202
8bbe3aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import os.path

import torch
from datasets import load_dataset
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

from src.retrievers.base_retriever import Retriever
from src.utils.log import get_logger

# Hacky fix for a FAISS error on macOS. The variable is set at import time,
# i.e. before any retriever in this module is constructed.
# See https://stackoverflow.com/a/63374568/4545692
# NOTE(review): per the linked answer this presumably tolerates duplicate
# OpenMP runtimes within one process -- confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


# Module-level logger obtained from the project's logging helper.
logger = get_logger()


class FaissRetriever(Retriever):
    """Retrieve relevant documents for a query using DPR encoders and FAISS.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
        """Initialize the retriever.

        Loads the DPR context/question encoders and tokenizers, then loads
        the dataset and attaches a FAISS index to it.

        Args:
            dataset_name (str, optional): The HuggingFace dataset to index.
                Assumes the information is stored in a column named 'text'.
                Defaults to "GroNLP/ik-nlp-22_slp".
        """
        # NOTE: this switches autograd off for the whole process, not only
        # for this retriever. Kept for backward compatibility; the embedding
        # helpers below additionally guard themselves with torch.no_grad().
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )

        # Dataset building
        self.dataset_name = dataset_name
        self.dataset = self._init_dataset(dataset_name)

    def _init_dataset(
            self,
            dataset_name: str,
            embedding_path: str = "./src/models/paragraphs_embedding.faiss",
            force_new_embedding: bool = False):
        """Load the dataset and attach a FAISS index named 'embeddings'.

        Args:
            dataset_name (str): A HuggingFace dataset name. Its "paragraphs"
                configuration and "train" split are loaded.
            embedding_path (str, optional): File used to cache the FAISS
                index on disk for faster loading after the first run.
            force_new_embedding (bool, optional): If True, recompute the
                embeddings even when `embedding_path` already exists.

        Returns:
            Dataset: The dataset with a FAISS index called 'embeddings'.
            When the index is freshly computed, the dataset also carries a
            new 'embeddings' column.
        """
        # Load dataset
        ds = load_dataset(dataset_name, name="paragraphs")[
            "train"]  # type: ignore

        if not force_new_embedding and os.path.exists(embedding_path):
            # If we already have FAISS embeddings, load them from disk
            ds.load_faiss_index('embeddings', embedding_path)  # type: ignore
            return ds

        # If there are no (usable) FAISS embeddings, generate them
        def embed(row):
            # Inline helper function to perform embedding of one row
            tok = self.ctx_tokenizer(
                row["text"], return_tensors="pt", truncation=True)
            # no_grad so that .numpy() works even if autograd was re-enabled
            # globally after this object was constructed.
            with torch.no_grad():
                enc = self.ctx_encoder(**tok)[0][0].numpy()
            return {"embeddings": enc}

        # Add FAISS embeddings
        ds_with_embeddings = ds.map(embed)  # type: ignore
        ds_with_embeddings.add_faiss_index(column="embeddings")

        # Save the index next to wherever `embedding_path` points, instead of
        # always creating the default "./src/models/" directory.
        embedding_dir = os.path.dirname(embedding_path)
        if embedding_dir:
            os.makedirs(embedding_dir, exist_ok=True)
        ds_with_embeddings.save_faiss_index("embeddings", embedding_path)

        return ds_with_embeddings

    def retrieve(self, query: str, k: int = 5):
        """Retrieve the top k matches for a search query.

        Args:
            query (str): A search query.
            k (int, optional): The number of documents to retrieve. Defaults
                to 5.

        Returns:
            tuple: `(scores, results)` as returned by the dataset's
            `get_nearest_examples` for the 'embeddings' index.
        """

        def embed(q):
            # Inline helper function to perform embedding of the query
            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
            # no_grad so that .numpy() works regardless of the global
            # autograd state at call time.
            with torch.no_grad():
                return self.q_encoder(**tok)[0][0].numpy()

        question_embedding = embed(query)
        scores, results = self.dataset.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )

        return scores, results