File size: 1,741 Bytes
0157dfd
b06298d
 
e9df5ab
af461f3
0157dfd
 
b06298d
 
0157dfd
b06298d
1fb8ae3
0157dfd
 
51a31d4
 
 
e9df5ab
 
1fb8ae3
af461f3
 
 
51a31d4
af461f3
a1746cf
 
 
af461f3
0157dfd
 
 
 
 
 
 
af461f3
e9df5ab
a1746cf
 
af461f3
 
e9df5ab
 
 
 
af461f3
b06298d
e9df5ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import imp
import os

from datasets import DatasetDict
from elasticsearch import Elasticsearch
from elastic_transport import ConnectionError
from dotenv import load_dotenv

from src.retrievers.base_retriever import RetrieveType, Retriever
from src.utils.log import logger
from src.utils.timing import timeit


# Pull variables from a local .env file (if one exists) into the process
# environment, so the ELASTIC_* settings read via os.getenv in ESRetriever can
# be configured without exporting them in the shell.
load_dotenv()


class ESRetriever(Retriever):
    """Retriever backed by an Elasticsearch index over paragraph text.

    Wraps the ``train`` split of a ``DatasetDict`` and attaches an
    Elasticsearch index over its ``text`` column. If an index with the same
    name already exists on the server it is reused instead of rebuilt.

    Connection settings come from environment variables:

    * ``ELASTIC_HOST`` - server URL, default ``"https://localhost:9200"``.
    * ``ELASTIC_USERNAME`` / ``ELASTIC_PASSWORD`` - basic-auth credentials;
      auth is only sent when both are set.
    * ``ELASTIC_CA_CERTS`` - CA certificate used to verify the server,
      default ``"./http_ca.crt"`` (resolved against the working directory).
    """

    # Used both as the datasets-side index name and the server-side
    # Elasticsearch index name.
    INDEX_NAME = "paragraphs"

    def __init__(self, paragraphs: DatasetDict) -> None:
        """Connect to Elasticsearch and attach the paragraph index.

        Args:
            paragraphs: Dataset whose ``train`` split holds the paragraphs to
                search; that split must have a ``text`` column, which is the
                column indexed when the index has to be created.

        Raises:
            SystemExit: with status 1 if the server cannot be reached.
        """
        self.paragraphs = paragraphs["train"]

        # elasticsearch-py 8 expects a full URL (scheme, host and port); a
        # bare "localhost" is rejected when the client is constructed.
        es_host = os.getenv("ELASTIC_HOST", "https://localhost:9200")

        self.client = self._connect(es_host)
        self._attach_index(es_host)

    @staticmethod
    def _connect(es_host: str) -> Elasticsearch:
        """Create a client for ``es_host`` and check the server responds."""
        es_username = os.getenv("ELASTIC_USERNAME")
        es_password = os.getenv("ELASTIC_PASSWORD")

        client_kwargs = {
            "ca_certs": os.getenv("ELASTIC_CA_CERTS", "./http_ca.crt"),
        }
        # Only send credentials when both are configured; a (None, None)
        # pair is not a usable credential and should not be sent as one.
        if es_username and es_password:
            client_kwargs["basic_auth"] = (es_username, es_password)

        client = Elasticsearch(hosts=[es_host], **client_kwargs)

        try:
            client.info()
        except ConnectionError as err:
            logger.error("Could not connect to ElasticSearch. "
                         "Make sure it is running. Exiting now...")
            # Non-zero status so callers/scripts see the failure. Raising
            # SystemExit directly avoids depending on the site-module
            # ``exit`` helper, which is not available under ``python -S``.
            raise SystemExit(1) from err

        return client

    def _attach_index(self, es_host: str) -> None:
        """Reuse the server-side index if it exists, otherwise build it."""
        name = self.INDEX_NAME
        if self.client.indices.exists(index=name):
            self.paragraphs.load_elasticsearch_index(
                name, es_index_name=name, es_client=self.client)
        else:
            logger.info(f"Creating index '{name}' on {es_host}")
            self.paragraphs.add_elasticsearch_index(
                column="text",
                index_name=name,
                es_index_name=name,
                es_client=self.client,
            )

    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        """Return the ``k`` nearest paragraphs to ``query``.

        Delegates to the dataset's attached Elasticsearch index; the result
        is whatever ``get_nearest_examples`` returns (typed as
        ``RetrieveType``).
        """
        return self.paragraphs.get_nearest_examples(
            self.INDEX_NAME, query, k)