File size: 1,383 Bytes
af461f3 51a31d4 1fb8ae3 af461f3 1fb8ae3 51a31d4 af461f3 1fb8ae3 af461f3 51a31d4 af461f3 a1746cf af461f3 a1746cf af461f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from datasets import DatasetDict, load_dataset
from src.utils.log import get_logger
from src.retrievers.base_retriever import Retriever
from elasticsearch import Elasticsearch
from dotenv import load_dotenv
import os
# Module-level logger obtained from the project's logging helper.
logger = get_logger()
class ESRetriever(Retriever):
    """Retriever that searches a dataset through an Elasticsearch index.

    Connection settings are read from the environment variables
    ``ELASTIC_HOST`` (defaults to ``"localhost"``), ``ELASTIC_USERNAME``
    and ``ELASTIC_PASSWORD``.
    """

    def __init__(
        self,
        dataset: DatasetDict,
        *,
        index_name: str = "paragraphs",
        text_column: str = "text",
        ca_certs: str = "./http_ca.crt",
    ) -> None:
        """Connect to Elasticsearch and attach an index to the train split.

        If the index already exists on the server it is attached to the
        dataset; otherwise it is created from ``text_column``.

        Args:
            dataset: Dataset dictionary; only its ``"train"`` split is used.
            index_name: Name used both for the Elasticsearch index and for
                the index registered on the dataset.
            text_column: Dataset column that is indexed when the index has
                to be created.
            ca_certs: Path to the CA certificate handed to the client.
        """
        self.dataset = dataset["train"]
        # Remembered so that retrieve() queries the index attached here.
        self._index_name = index_name

        # NOTE(review): nothing in this block loads a .env file; presumably
        # the environment is populated before construction -- confirm.
        es_host = os.getenv("ELASTIC_HOST", "localhost")
        es_username = os.getenv("ELASTIC_USERNAME")
        es_password = os.getenv("ELASTIC_PASSWORD")

        self.client = Elasticsearch(
            hosts=[es_host],
            http_auth=(es_username, es_password),
            ca_certs=ca_certs,
        )

        if self.client.indices.exists(index=index_name):
            # Reuse the index that is already present on the server.
            self.dataset.load_elasticsearch_index(
                index_name,
                es_index_name=index_name,
                es_client=self.client,
            )
        else:
            logger.info(f"Creating index '{index_name}' on {es_host}")
            self.dataset.add_elasticsearch_index(
                column=text_column,
                index_name=index_name,
                es_index_name=index_name,
                es_client=self.client,
            )

    def retrieve(self, query: str, k: int = 5):
        """Return the ``k`` nearest examples for ``query``.

        Args:
            query: Search string passed to the index.
            k: Number of examples to request.

        Returns:
            The result of ``get_nearest_examples`` on the train split,
            unchanged.
        """
        return self.dataset.get_nearest_examples(self._index_name, query, k)
|