import os
import sys

from datasets import DatasetDict
from dotenv import load_dotenv
# Aliased so the builtin ``ConnectionError`` is not shadowed in this module.
from elastic_transport import ConnectionError as ESConnectionError
from elasticsearch import Elasticsearch

from src.retrievers.base_retriever import RetrieveType, Retriever
from src.utils.log import logger
from src.utils.timing import timeit

# NOTE: `import imp` was removed: it was unused here and the `imp` module no
# longer exists on Python 3.12+, where importing it raises ImportError.

load_dotenv()


class ESRetriever(Retriever):
    """Retriever that queries an Elasticsearch index built over paragraph text.

    Connection settings are read from environment variables (a ``.env`` file
    is loaded at import time via ``load_dotenv``):

    - ``ELASTIC_HOST``: node URL, defaults to ``"https://localhost:9200"``.
    - ``ELASTIC_USERNAME`` / ``ELASTIC_PASSWORD``: basic-auth credentials;
      authentication is only configured when both are set.
    - ``ELASTIC_CA_CERTS``: path to the CA certificate used to verify the
      node, defaults to ``"./http_ca.crt"``.
    """

    # Used both as the `datasets` index name and the Elasticsearch index name.
    INDEX_NAME = "paragraphs"

    def __init__(self, paragraphs: DatasetDict) -> None:
        """Connect to Elasticsearch and attach (or build) the paragraph index.

        Args:
            paragraphs: Dataset dict whose ``"train"`` split holds the
                paragraphs; its ``"text"`` column is what gets indexed.

        Exits the process with status 1 if Elasticsearch is unreachable.
        """
        self.paragraphs = paragraphs["train"]

        # The 8.x client (implied by the elastic_transport import) requires
        # a full URL with scheme and port; a bare "localhost" is rejected.
        es_host = os.getenv("ELASTIC_HOST", "https://localhost:9200")
        es_username = os.getenv("ELASTIC_USERNAME")
        es_password = os.getenv("ELASTIC_PASSWORD")
        ca_certs = os.getenv("ELASTIC_CA_CERTS", "./http_ca.crt")

        client_kwargs = {"hosts": [es_host], "ca_certs": ca_certs}
        # `http_auth` is deprecated in the 8.x client; `basic_auth` is its
        # replacement. Only send credentials when both halves are present,
        # instead of passing a (None, None) tuple.
        if es_username is not None and es_password is not None:
            client_kwargs["basic_auth"] = (es_username, es_password)
        self.client = Elasticsearch(**client_kwargs)

        try:
            # Cheap round-trip to verify the cluster is actually reachable.
            self.client.info()
        except ESConnectionError:
            logger.error("Could not connect to ElasticSearch. "
                         "Make sure it is running. Exiting now...")
            # Non-zero status so callers/shells see the failure; the bare
            # `exit()` builtin exited with 0 and depends on the `site` module.
            sys.exit(1)

        if self.client.indices.exists(index=self.INDEX_NAME):
            # Reuse the existing ES index instead of re-indexing every run.
            self.paragraphs.load_elasticsearch_index(
                self.INDEX_NAME,
                es_index_name=self.INDEX_NAME,
                es_client=self.client,
            )
        else:
            logger.info(f"Creating index '{self.INDEX_NAME}' on {es_host}")
            self.paragraphs.add_elasticsearch_index(
                column="text",
                index_name=self.INDEX_NAME,
                es_index_name=self.INDEX_NAME,
                es_client=self.client,
            )

    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        """Return the ``k`` paragraphs best matching ``query``.

        Delegates to ``Dataset.get_nearest_examples`` on the Elasticsearch
        index attached in ``__init__``.
        """
        return self.paragraphs.get_nearest_examples(self.INDEX_NAME, query, k)