File size: 1,383 Bytes
af461f3
51a31d4
1fb8ae3
af461f3
 
 
1fb8ae3
51a31d4
 
 
 
af461f3
 
1fb8ae3
af461f3
 
 
51a31d4
af461f3
a1746cf
 
 
af461f3
 
 
a1746cf
 
af461f3
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from datasets import DatasetDict, load_dataset
from src.utils.log import get_logger
from src.retrievers.base_retriever import Retriever
from elasticsearch import Elasticsearch
from dotenv import load_dotenv
import os

# Module-level logger, obtained from the project's logging helper.
logger = get_logger()


class ESRetriever(Retriever):
    """Retriever that searches the dataset through an Elasticsearch index.

    The "train" split of the given dataset is attached to an Elasticsearch
    index over its "text" column. Connection settings are read from the
    environment:

    * ``ELASTIC_HOST`` - host passed to the client (default "localhost")
    * ``ELASTIC_USERNAME`` / ``ELASTIC_PASSWORD`` - HTTP auth credentials
    """

    # Used both as the `datasets` index handle and as the name of the
    # index inside Elasticsearch.
    INDEX_NAME = "paragraphs"

    def __init__(self, dataset: DatasetDict) -> None:
        """Connect to Elasticsearch and attach the index to the dataset.

        If the index already exists on the server it is loaded; otherwise
        it is created from the dataset's "text" column.

        Args:
            dataset: Dataset dictionary; only its "train" split is used.
        """
        self.dataset = dataset["train"]

        # Pull settings from a .env file (if one is found) before reading
        # them. By default this does not override variables that are
        # already set in the process environment.
        load_dotenv()

        es_host = os.getenv("ELASTIC_HOST", "localhost")
        es_password = os.getenv("ELASTIC_PASSWORD")
        es_username = os.getenv("ELASTIC_USERNAME")

        # NOTE: the CA certificate path is relative to the current working
        # directory of the process.
        self.client = Elasticsearch(
            hosts=[es_host],
            http_auth=(es_username, es_password),
            ca_certs="./http_ca.crt")

        if self.client.indices.exists(index=self.INDEX_NAME):
            # Index is already on the server: just bind it to the dataset.
            self.dataset.load_elasticsearch_index(
                self.INDEX_NAME, es_index_name=self.INDEX_NAME,
                es_client=self.client)
        else:
            logger.info(f"Creating index 'paragraphs' on {es_host}")
            self.dataset.add_elasticsearch_index(
                column="text",
                index_name=self.INDEX_NAME,
                es_index_name=self.INDEX_NAME,
                es_client=self.client)

    def retrieve(self, query: str, k: int = 5):
        """Return the ``k`` best matches for ``query`` from the index.

        Args:
            query: Free-text search query.
            k: Number of examples to retrieve (default 5).

        Returns:
            The result of ``Dataset.get_nearest_examples`` for the
            "paragraphs" index (in `datasets`, a scores/examples pair).
        """
        return self.dataset.get_nearest_examples(self.INDEX_NAME, query, k)