Ramon Meffert
Small fixes to retrievers
a1746cf
raw
history blame
1.38 kB
from datasets import DatasetDict, load_dataset
from src.utils.log import get_logger
from src.retrievers.base_retriever import Retriever
from elasticsearch import Elasticsearch
from dotenv import load_dotenv
import os
# Module-level logger obtained from the project's logging helper
# (src.utils.log.get_logger); used below to report index creation.
logger = get_logger()
class ESRetriever(Retriever):
    """Retriever backed by an Elasticsearch index over a dataset's "train" split.

    On construction the instance connects to Elasticsearch using settings
    read from the environment and attaches an index named "paragraphs" to
    the dataset: an already-existing index is loaded, otherwise one is
    built from the dataset's "text" column.
    """

    def __init__(self, dataset: DatasetDict) -> None:
        """Connect to Elasticsearch and attach the "paragraphs" index.

        Environment variables read:
            ELASTIC_HOST: host to connect to (falls back to "localhost").
            ELASTIC_USERNAME / ELASTIC_PASSWORD: credentials passed as
                ``http_auth``; no fallback, so either may be ``None``.

        Args:
            dataset: Dataset dictionary; only its "train" split is kept.
        """
        self.dataset = dataset["train"]

        # Connection settings come from the environment; only the host
        # has a default value.
        es_host = os.getenv("ELASTIC_HOST", "localhost")
        credentials = (
            os.getenv("ELASTIC_USERNAME"),
            os.getenv("ELASTIC_PASSWORD"),
        )
        self.client = Elasticsearch(
            hosts=[es_host],
            http_auth=credentials,
            # CA certificate path is relative to the working directory.
            ca_certs="./http_ca.crt",
        )

        # The same name is used for both the dataset-side index handle
        # and the index on the Elasticsearch server.
        index_name = "paragraphs"

        if not self.client.indices.exists(index=index_name):
            # No index on the server yet: build it from the "text" column.
            logger.info(f"Creating index 'paragraphs' on {es_host}")
            self.dataset.add_elasticsearch_index(
                column="text",
                index_name=index_name,
                es_index_name=index_name,
                es_client=self.client,
            )
            return

        # Index already exists on the server: just attach it to the dataset.
        self.dataset.load_elasticsearch_index(
            index_name,
            es_index_name=index_name,
            es_client=self.client,
        )

    def retrieve(self, query: str, k: int = 5):
        """Look up the nearest examples for ``query`` in the "paragraphs" index.

        Args:
            query: Text to search for.
            k: Number of examples to request (default 5).

        Returns:
            The result of ``self.dataset.get_nearest_examples`` unchanged
            (per the ``datasets`` documentation, scores together with the
            matching examples -- confirm against the installed version).
        """
        nearest = self.dataset.get_nearest_examples("paragraphs", query, k)
        return nearest