olamidegoriola pvanand commited on
Commit
8e6ed3b
1 Parent(s): 61f0337

Create FAISS index using omdena qna dataset (#4)

Browse files

- Create FAISS index using omdena qna dataset (b38f7550122233810db9ef6a3aa08a90a1959fdd)


Co-authored-by: Anand <[email protected]>

Files changed (1) hide show
  1. create_faiss_index.py +58 -0
create_faiss_index.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """create_faiss_index.py
3
+ """
4
+
5
+ import pandas as pd
6
+ import numpy as np
7
+ import faiss
8
+ from sentence_transformers import InputExample, SentenceTransformer
9
+
10
+ DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"
11
+ TRANSFORMER_MODEL_NAME = "all-distilroberta-v1"
12
+ CACHE_DIR_PATH = "../working/cache/"
13
+ MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
14
+ FAISS_INDEX_FILE_PATH = "index.faiss"
15
+
16
+ def load_data(file_path):
17
+ qna_dataset = pd.read_csv(file_path)
18
+ qna_dataset["id"] = qna_dataset.index
19
+ return qna_dataset.dropna(subset=['Answers']).copy()
20
+
21
+ def create_input_examples(qna_dataset):
22
+ qna_dataset['QNA'] = qna_dataset.apply(lambda row: f"Question: {row['Questions']}, Answer: {row['Answers']}", axis=1)
23
+ return qna_dataset.apply(lambda x: InputExample(texts=[x["QNA"]]), axis=1).tolist()
24
+
25
+ def load_transformer_model(model_name, cache_folder):
26
+ transformer_model = SentenceTransformer(model_name, cache_folder=cache_folder)
27
+ return transformer_model
28
+
29
+ def save_transformer_model(transformer_model, model_file):
30
+ transformer_model.save(model_file)
31
+
32
+ def create_faiss_index(transformer_model, qna_dataset):
33
+ faiss_embeddings = transformer_model.encode(qna_dataset.Answers.values.tolist())
34
+ qna_dataset_indexed = qna_dataset.set_index(["id"], drop=False)
35
+ id_index_array = np.array(qna_dataset_indexed.id.values).flatten().astype("int")
36
+ normalized_embeddings = faiss_embeddings.copy()
37
+ faiss.normalize_L2(normalized_embeddings)
38
+ faiss_index = faiss.IndexIDMap(faiss.IndexFlatIP(len(faiss_embeddings[0])))
39
+ faiss_index.add_with_ids(normalized_embeddings, id_index_array)
40
+ return faiss_index
41
+
42
+ def save_faiss_index(faiss_index, filename):
43
+ faiss.write_index(faiss_index, filename)
44
+
45
+ def load_faiss_index(filename):
46
+ return faiss.read_index(filename)
47
+
48
+ def main():
49
+ qna_dataset = load_data(DATA_FILE_PATH)
50
+ input_examples = create_input_examples(qna_dataset)
51
+ transformer_model = load_transformer_model(TRANSFORMER_MODEL_NAME, CACHE_DIR_PATH)
52
+ save_transformer_model(transformer_model, MODEL_SAVE_PATH)
53
+ faiss_index = create_faiss_index(transformer_model, qna_dataset)
54
+ save_faiss_index(faiss_index, FAISS_INDEX_FILE_PATH)
55
+ faiss_index = load_faiss_index(FAISS_INDEX_FILE_PATH)
56
+
57
+ if __name__ == "__main__":
58
+ main()