Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
"""create_faiss_index.py | |
""" | |
import pandas as pd | |
import numpy as np | |
import faiss | |
from sentence_transformers import InputExample, SentenceTransformer | |
# Input CSV; main() passes it to load_data, and the code below expects
# "Questions" and "Answers" columns.
DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"

# Pretrained sentence-transformer checkpoint and its local download cache.
TRANSFORMER_MODEL_NAME = "all-distilroberta-v1"
CACHE_DIR_PATH = "../working/cache/"

# Output locations written by main().
# NOTE(review): SentenceTransformer.save usually writes a directory, so the
# ".pkl" suffix below may be misleading — confirm before relying on it.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
FAISS_INDEX_FILE_PATH = "index.faiss"
def load_data(file_path):
    """Read the Q&A CSV and keep only the rows that have an answer.

    An ``id`` column is filled from the frame's default 0-based index *before*
    filtering. Each surviving row therefore keeps its original row position as
    its id, even after unanswered rows are dropped.

    Args:
        file_path: Anything ``pd.read_csv`` accepts. The file must contain an
            ``Answers`` column.

    Returns:
        A new DataFrame with the original columns plus ``id``, restricted to
        rows whose ``Answers`` value is not missing. The index is preserved.
    """
    raw_frame = pd.read_csv(file_path)
    with_ids = raw_frame.assign(id=raw_frame.index)
    has_answer = with_ids["Answers"].notna()
    return with_ids.loc[has_answer].copy()
def create_input_examples(qna_dataset):
    """Build one sentence-transformers ``InputExample`` per question/answer row.

    Side effect: adds (or overwrites) a ``QNA`` column on *qna_dataset*. Each
    cell holds the text ``"Question: <Questions>, Answer: <Answers>"``.

    Args:
        qna_dataset: DataFrame with ``Questions`` and ``Answers`` columns.

    Returns:
        A list of ``InputExample`` objects in row order, each built with
        ``texts=[<QNA text>]``. The list is empty when *qna_dataset* has no rows.
    """
    # A zip over the two columns replaces two row-wise DataFrame.apply(axis=1)
    # calls. It is much faster, and it yields an empty list for an empty frame.
    # apply(axis=1) would instead return a DataFrame, which made the column
    # assignment and the .tolist() call crash.
    qna_texts = [
        f"Question: {question}, Answer: {answer}"
        for question, answer in zip(qna_dataset["Questions"], qna_dataset["Answers"])
    ]
    qna_dataset["QNA"] = qna_texts
    return [InputExample(texts=[text]) for text in qna_texts]
def load_transformer_model(model_name, cache_folder):
    """Construct a ``SentenceTransformer`` for *model_name*.

    Args:
        model_name: Checkpoint name or path handed to ``SentenceTransformer``.
        cache_folder: Directory passed as ``cache_folder`` for downloaded files.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name, cache_folder=cache_folder)
def save_transformer_model(transformer_model, model_file):
    """Persist *transformer_model* by calling its own ``save(model_file)``.

    NOTE(review): for a SentenceTransformer this is expected to write a
    directory rather than a single file — confirm if callers assume a file.
    """
    persist = transformer_model.save
    persist(model_file)
def create_faiss_index(transformer_model, qna_dataset):
    """Build a cosine-similarity FAISS index over the ``Answers`` column.

    The steps are:

    1. Embed each answer with ``transformer_model.encode``.
    2. L2-normalise the embeddings, so that inner product equals cosine
       similarity.
    3. Store them in an ``IndexFlatIP`` wrapped in an ``IndexIDMap``, so that
       searches return values from the ``id`` column rather than positional
       offsets.

    Only ``Answers`` is embedded, not the combined ``QNA`` text.

    Args:
        transformer_model: Object exposing ``encode(list_of_str)``, e.g. the
            SentenceTransformer returned by ``load_transformer_model``.
        qna_dataset: DataFrame with ``Answers`` and integer ``id`` columns, as
            produced by ``load_data``.

    Returns:
        A ``faiss.IndexIDMap`` holding one vector per row of *qna_dataset*.

    Raises:
        ValueError: If *qna_dataset* has no rows. There is then nothing to index
            and the embedding dimension cannot be inferred.
    """
    answers = qna_dataset["Answers"].tolist()
    if not answers:
        raise ValueError("cannot build a FAISS index from an empty dataset")

    # FAISS requires C-contiguous float32 input. np.array copies by default,
    # so the in-place normalisation below never mutates the encoder's output.
    embeddings = np.array(
        transformer_model.encode(answers), dtype=np.float32, order="C"
    )
    faiss.normalize_L2(embeddings)

    # FAISS ids are 64-bit. astype("int") gives int32 on some platforms
    # (e.g. Windows with NumPy < 2), which add_with_ids does not accept.
    ids = qna_dataset["id"].to_numpy(dtype=np.int64)

    index = faiss.IndexIDMap(faiss.IndexFlatIP(embeddings.shape[1]))
    index.add_with_ids(embeddings, ids)
    return index
def save_faiss_index(faiss_index, filename):
    """Serialize *faiss_index* to *filename* using ``faiss.write_index``."""
    faiss.write_index(faiss_index, filename)
    return None
def load_faiss_index(filename):
    """Deserialize and return the FAISS index stored at *filename*."""
    restored_index = faiss.read_index(filename)
    return restored_index
def main():
    """Build the answer-embedding FAISS index and persist the model and index.

    The steps are:

    1. Load the Q&A CSV.
    2. Load the sentence-transformer model and save a copy of it.
    3. Embed the answers into a FAISS index and write the index to disk.
    4. Read the index back to confirm the round trip.

    Raises:
        RuntimeError: If the index read back from disk does not hold the same
            number of vectors that were written.
    """
    qna_dataset = load_data(DATA_FILE_PATH)
    # create_input_examples() is not called here. Its result was never used,
    # and create_faiss_index reads only the Answers and id columns, not the
    # QNA column that create_input_examples adds.
    transformer_model = load_transformer_model(TRANSFORMER_MODEL_NAME, CACHE_DIR_PATH)
    save_transformer_model(transformer_model, MODEL_SAVE_PATH)

    faiss_index = create_faiss_index(transformer_model, qna_dataset)
    save_faiss_index(faiss_index, FAISS_INDEX_FILE_PATH)

    # Sanity-check the file on disk: the reloaded index must hold every vector.
    reloaded_index = load_faiss_index(FAISS_INDEX_FILE_PATH)
    if reloaded_index.ntotal != faiss_index.ntotal:
        raise RuntimeError(
            f"reloaded index holds {reloaded_index.ntotal} vectors, "
            f"expected {faiss_index.ntotal}"
        )
if __name__ == "__main__":
    # Run the build pipeline only when executed as a script, not on import.
    main()