ASledziewska commited on
Commit
7145d2f
1 Parent(s): ab8f613

Create Chromadb_storage_JyotiNigam.py

Browse files
Files changed (1) hide show
  1. Chromadb_storage_JyotiNigam.py +63 -0
Chromadb_storage_JyotiNigam.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ from nltk.tokenize import word_tokenize
3
+ from langchain_community.document_loaders import TextLoader
4
+ from langchain_community.embeddings.sentence_transformer import (
5
+ SentenceTransformerEmbeddings,
6
+ )
7
+ from langchain_community.vectorstores import Chroma
8
+ from langchain_text_splitters import CharacterTextSplitter
9
+
10
+ # Download NLTK data for tokenization
11
+ nltk.download('punkt')
12
+ import os
13
+ global db
14
+ class QuestionRetriever:
15
+
16
+ def load_documents(self,file_name):
17
+ data_directory = "data/"
18
+ file_path = os.path.join(data_directory, file_name)
19
+ loader = TextLoader(file_path)
20
+ documents = loader.load()
21
+ return documents
22
+
23
+ def store_data_in_vector_db(self,documents):
24
+ # global db
25
+ text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0,separator="\n")
26
+ docs = text_splitter.split_documents(documents)
27
+ # create the open-source embedding function
28
+ embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
29
+ # print(docs)
30
+ # load it into Chroma
31
+ db = Chroma.from_documents(docs, embedding_function)
32
+ return db
33
+
34
+ def get_response(self, user_query, predicted_mental_category):
35
+ if predicted_mental_category == "depression":
36
+ documents=self.load_documents("depression_questions.txt")
37
+
38
+ elif predicted_mental_category == "adhd":
39
+ documents=self.load_documents("adhd_questions.txt")
40
+
41
+ elif predicted_mental_category == "anxiety":
42
+ documents=self.load_documents("anxiety_questions.txt")
43
+
44
+ else:
45
+ print("Sorry, allowed predicted_mental_category is ['depresison', 'adhd', 'anxiety'].")
46
+ return
47
+ db=self.store_data_in_vector_db(documents)
48
+
49
+ docs = db.similarity_search(user_query)
50
+ most_similar_question = docs[0].page_content.split("\n")[0] # Extract the first question
51
+ if user_query==most_similar_question:
52
+ most_similar_question=docs[1].page_content.split("\n")[0]
53
+
54
+ print(most_similar_question)
55
+ return most_similar_question
56
+
57
+ if __name__ == "__main__":
58
+ model = QuestionRetriever()
59
+ user_input = input("User: ")
60
+
61
+ predicted_mental_condition = "depression"
62
+ response = model.get_response(user_input, predicted_mental_condition)
63
+ print("Chatbot:", response)