File size: 3,860 Bytes
10641ee
8bd9363
10641ee
 
 
 
 
8bd9363
 
10641ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bd9363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10641ee
8bd9363
10641ee
 
8bd9363
 
 
 
 
10641ee
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from haystack.document_stores import InMemoryDocumentStore

from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.nodes.retriever import EmbeddingRetriever
import pickle
from pprint import pprint
# Names identifying the two datasets. The module-level `do_search` compares its
# `dataset` argument against the German name to decide which engine to query.
# NOTE(review): "datset" is a typo for "dataset"; the identifiers are left
# unchanged because they are referenced elsewhere in this module and may be
# imported by other files.
dutch_datset_name = 'Partisan news 2019 (dutch)'
german_datset_name = 'CDU election program 2021'

class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    InMemoryDocumentStore whose ``indexes`` attribute can be pickled to a file
    and restored from it.

    When the application is deployed to Huggingface Spaces there is no GPU
    available, so pre-calculated data has to be loaded into the store from
    disk.
    """

    def load_data(self, file_name='in_memory_store.pkl'):
        """Replace ``self.indexes`` with the object unpickled from ``file_name``."""
        # NOTE(review): unpickling can run arbitrary code; only load files that
        # come from a trusted source.
        with open(file_name, 'rb') as pickle_file:
            restored_indexes = pickle.load(pickle_file)
        self.indexes = restored_indexes

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle ``self.indexes`` into ``file_name`` (opened in ``'wb'`` mode)."""
        with open(file_name, 'wb') as pickle_file:
            pickle.dump(self.indexes, pickle_file)


class SearchEngine:
    """
    Query one dataset with three retrievers and report the first hit of each.

    Two document stores are loaded from pickle files. The first one backs a
    TF-IDF retriever and an embedding retriever that uses the
    ``paraphrase-multilingual-mpnet-base-v2`` sentence-transformers model; the
    second one backs an embedding retriever that uses the adapted model.
    """

    def __init__(self, document_store_name_base, document_store_name_adpated,
                 adapted_retriever_path):
        """
        Load both document stores and create the three retrievers.

        :param document_store_name_base: Pickle file loaded into the store used
            by the TF-IDF retriever and the base embedding retriever.
        :param document_store_name_adpated: Pickle file loaded into the store
            used by the adapted embedding retriever. The misspelled name is
            kept so that keyword callers keep working.
        :param adapted_retriever_path: Value passed as ``embedding_model`` to
            the adapted embedding retriever.
        """
        self.document_store = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store.load_data(document_store_name_base)

        self.document_store_adapted = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store_adapted.load_data(document_store_name_adpated)

        self.retriever = TfidfRetriever(document_store=self.document_store)

        self.base_dense_retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            model_format='sentence_transformers'
        )

        self.fine_tuned_retriever = EmbeddingRetriever(
            document_store=self.document_store_adapted,
            embedding_model=adapted_retriever_path,
            model_format='sentence_transformers'
        )

    @staticmethod
    def _first_document(result):
        """Return the first entry of ``result['documents']``, or ``None`` if it is empty."""
        documents = result['documents']
        return documents[0] if documents else None

    def sparse_retrieval(self, query):
        """
        Run the TF-IDF retriever on ``query`` through a DocumentSearchPipeline.

        The score of the first returned document is overwritten with the first
        value reported by the retriever's private ``_calc_scores`` method
        (presumably the best TF-IDF score; verify against the installed
        Haystack version).

        :param query: Query string.
        :return: The dictionary returned by the pipeline run; its
            ``'documents'`` entry may be empty.
        """
        scores = self.retriever._calc_scores(query)
        p_retrieval = DocumentSearchPipeline(self.retriever)
        documents = p_retrieval.run(query=query)

        hits = documents['documents']
        # Only patch the score when there is a hit to patch; indexing an empty
        # result used to raise IndexError.
        if hits:
            first_query_scores = scores[0]
            if first_query_scores:
                hits[0].score = next(iter(first_query_scores.values()))
        return documents

    def dense_retrieval(self, query, retriever='base'):
        """
        Run one of the embedding retrievers on ``query``.

        :param query: Query string.
        :param retriever: ``'base'`` for the base embedding retriever or
            ``'adapted'`` for the adapted one.
        :return: The dictionary returned by the pipeline run.
        :raises ValueError: If ``retriever`` is neither ``'base'`` nor
            ``'adapted'``. Previously this case silently returned ``None``.
        """
        if retriever == 'base':
            selected = self.base_dense_retriever
        elif retriever == 'adapted':
            selected = self.fine_tuned_retriever
        else:
            raise ValueError(
                f"Unknown retriever {retriever!r}; expected 'base' or 'adapted'"
            )
        p_retrieval = DocumentSearchPipeline(selected)
        return p_retrieval.run(query=query)

    def do_search(self, query):
        """
        Query all three retrievers and return the first document of each.

        :param query: Query string.
        :return: Tuple ``(sparse, dense_base, dense_adapted)``. An element is
            ``None`` when the corresponding retriever returned no documents.
        """
        sparse_result = self._first_document(self.sparse_retrieval(query))
        dense_base_result = self._first_document(self.dense_retrieval(query, 'base'))
        dense_adapted_result = self._first_document(self.dense_retrieval(query, 'adapted'))
        return sparse_result, dense_base_result, dense_adapted_result


# Both engines are created at import time, so importing this module loads the
# pickled document stores and builds the retrievers for each dataset.
dutch_search_engine = SearchEngine(
    document_store_name_base='dutch-article-idx.pkl',
    document_store_name_adpated='dutch-article-idx_adapted.pkl',
    adapted_retriever_path='dutch-article-retriever',
)
german_search_engine = SearchEngine(
    document_store_name_base='documentstore_german-election-idx.pkl',
    document_store_name_adpated='documentstore_german-election-idx_adapted.pkl',
    adapted_retriever_path='adapted-retriever',
)

def do_search(query, dataset):
    """
    Run ``query`` on the search engine that belongs to ``dataset``.

    The German engine is used when ``dataset`` equals ``german_datset_name``;
    every other value is answered by the Dutch engine. The return value is
    whatever the selected engine's ``do_search`` returns.
    """
    wants_german = dataset == german_datset_name
    engine = german_search_engine if wants_german else dutch_search_engine
    return engine.do_search(query)

if __name__ == '__main__':
    # Manual check: run a sample query on the Dutch dataset and print the first
    # hit of each retriever. The module-level `dutch_search_engine` is reused
    # because it is built from the same three files; constructing another
    # SearchEngine here would load the document stores and retrievers a second
    # time.
    query = 'Kindergarten'

    result = dutch_search_engine.do_search(query)
    pprint(result)