File size: 1,741 Bytes
0157dfd b06298d e9df5ab af461f3 0157dfd b06298d 0157dfd b06298d 1fb8ae3 0157dfd 51a31d4 e9df5ab 1fb8ae3 af461f3 51a31d4 af461f3 a1746cf af461f3 0157dfd af461f3 e9df5ab a1746cf af461f3 e9df5ab af461f3 b06298d e9df5ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import imp
import os
from datasets import DatasetDict
from elasticsearch import Elasticsearch
from elastic_transport import ConnectionError
from dotenv import load_dotenv
from src.retrievers.base_retriever import RetrieveType, Retriever
from src.utils.log import logger
from src.utils.timing import timeit
load_dotenv()
class ESRetriever(Retriever):
    """Retriever that answers queries through an Elasticsearch index.

    The "train" split of the supplied dataset is attached to an
    Elasticsearch index named ``paragraphs``. Connection settings are read
    from the environment variables ``ELASTIC_HOST`` (defaults to
    ``"localhost"``), ``ELASTIC_USERNAME`` and ``ELASTIC_PASSWORD``.
    """

    # Single name used both as the dataset-side index handle and as the
    # index name on the Elasticsearch server.
    _INDEX_NAME = "paragraphs"

    def __init__(self, paragraphs: DatasetDict) -> None:
        """Connect to Elasticsearch and attach (or build) the index.

        Args:
            paragraphs: Dataset dictionary; only its ``"train"`` split is
                used, and the index is built over its ``"text"`` column.

        Raises:
            SystemExit: With exit status 1 when the Elasticsearch server
                cannot be reached.
        """
        self.paragraphs = paragraphs["train"]

        es_host = os.getenv("ELASTIC_HOST", "localhost")
        es_password = os.getenv("ELASTIC_PASSWORD")
        es_username = os.getenv("ELASTIC_USERNAME")

        self.client = Elasticsearch(
            hosts=[es_host],
            http_auth=(es_username, es_password),
            # Certificate path is relative to the current working directory.
            ca_certs="./http_ca.crt",
        )

        # Probe the server up front so an unreachable cluster is reported
        # immediately rather than on the first query.
        try:
            self.client.info()
        except ConnectionError as err:
            # NOTE: ``ConnectionError`` is the elastic_transport exception
            # imported at the top of the file, not the Python builtin.
            logger.error("Could not connect to ElasticSearch. " +
                         "Make sure it is running. Exiting now...")
            # Raise SystemExit directly instead of calling the bare
            # ``exit()`` helper: that helper comes from the ``site`` module
            # (absent with ``python -S`` / frozen builds) and would report
            # status 0, i.e. success, for a fatal startup failure.
            raise SystemExit(1) from err

        if self.client.indices.exists(index=self._INDEX_NAME):
            # The index already exists on the server: only attach to it.
            self.paragraphs.load_elasticsearch_index(
                self._INDEX_NAME,
                es_index_name=self._INDEX_NAME,
                es_client=self.client,
            )
        else:
            logger.info(f"Creating index 'paragraphs' on {es_host}")
            self.paragraphs.add_elasticsearch_index(
                column="text",
                index_name=self._INDEX_NAME,
                es_index_name=self._INDEX_NAME,
                es_client=self.client,
            )

    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        """Return the ``k`` nearest examples for ``query``.

        Args:
            query: Free-text search query.
            k: Number of examples to request (default 5).

        Returns:
            The result of ``get_nearest_examples`` on the ``paragraphs``
            index, passed through unchanged.
        """
        return self.paragraphs.get_nearest_examples(
            self._INDEX_NAME, query, k)
|