Gopikanth123 committed on
Commit
f6fbd63
·
verified ·
1 Parent(s): 6fe1f69

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +134 -202
main.py CHANGED
@@ -1,228 +1,160 @@
1
  import os
2
  import shutil
 
3
  from flask import Flask, render_template, request, jsonify
4
- from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
5
- from llama_index.llms.huggingface import HuggingFaceInferenceAPI
6
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
7
- from huggingface_hub import InferenceClient
8
- from transformers import AutoTokenizer, AutoModel
9
- from deep_translator import GoogleTranslator
10
-
11
-
12
- # Ensure HF_TOKEN is set
13
- HF_TOKEN = os.getenv("HF_TOKEN")
14
- if not HF_TOKEN:
15
- raise ValueError("HF_TOKEN environment variable not set.")
16
-
17
- repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
18
- llm_client = InferenceClient(
19
- model=repo_id,
20
- token=HF_TOKEN,
21
- )
22
-
23
- # Configure Llama index settings
24
- Settings.llm = HuggingFaceInferenceAPI(
25
- model_name=repo_id,
26
- tokenizer_name=repo_id,
27
- context_window=3000,
28
- token=HF_TOKEN,
29
- max_new_tokens=512,
30
- generate_kwargs={"temperature": 0.1},
31
- )
32
- # Settings.embed_model = HuggingFaceEmbedding(
33
- # model_name="BAAI/bge-small-en-v1.5"
34
- # )
35
- # Replace the embedding model with XLM-R
36
- # Settings.embed_model = HuggingFaceEmbedding(
37
- # model_name="xlm-roberta-base" # XLM-RoBERTa model for multilingual support
38
- # )
39
- Settings.embed_model = HuggingFaceEmbedding(
40
- model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
41
- )
42
-
43
- # Configure tokenizer and model if required
44
- tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
45
- model = AutoModel.from_pretrained("xlm-roberta-base")
46
 
 
47
  PERSIST_DIR = "db"
48
  PDF_DIRECTORY = 'data'
49
-
50
- # Ensure directories exist
51
  os.makedirs(PDF_DIRECTORY, exist_ok=True)
52
  os.makedirs(PERSIST_DIR, exist_ok=True)
53
- chat_history = []
54
- current_chat_history = []
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def data_ingestion_from_directory():
57
  # Clear previous data by removing the persist directory
58
  if os.path.exists(PERSIST_DIR):
59
  shutil.rmtree(PERSIST_DIR) # Remove the persist directory and all its contents
60
-
61
  # Recreate the persist directory after removal
62
  os.makedirs(PERSIST_DIR, exist_ok=True)
63
-
64
- # Load new documents from the directory
65
- new_documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
66
-
67
- # Create a new index with the new documents
68
- index = VectorStoreIndex.from_documents(new_documents)
69
-
70
- # Persist the new index
71
- index.storage_context.persist(persist_dir=PERSIST_DIR)
72
-
73
- # def handle_query(query):
74
- # context_str = ""
75
-
76
- # # Build context from current chat history
77
- # for past_query, response in reversed(current_chat_history):
78
- # if past_query.strip():
79
- # context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
80
-
81
- # chat_text_qa_msgs = [
82
- # (
83
- # "user",
84
- # """
85
- # You are the Taj Hotel voice chatbot and your name is Taj hotel helper. Your goal is to provide accurate, professional, and helpful answers to user queries based on the Taj hotel data. Always ensure your responses are clear and concise. Give response within 10-15 words only. You need to give an answer in the same language used by the user.
86
- # {context_str}
87
- # Question:
88
- # {query_str}
89
- # """
90
- # )
91
- # ]
92
-
93
-
94
-
95
- # text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
96
-
97
- # storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
98
- # index = load_index_from_storage(storage_context)
99
- # # context_str = ""
100
-
101
- # # # Build context from current chat history
102
- # # for past_query, response in reversed(current_chat_history):
103
- # # if past_query.strip():
104
- # # context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
105
-
106
- # query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
107
- # print(f"Querying: {query}")
108
- # answer = query_engine.query(query)
109
-
110
- # # Extracting the response
111
- # if hasattr(answer, 'response'):
112
- # response = answer.response
113
- # elif isinstance(answer, dict) and 'response' in answer:
114
- # response = answer['response']
115
- # else:
116
- # response = "I'm sorry, I couldn't find an answer to that."
117
-
118
- # # Append to chat history
119
- # current_chat_history.append((query, response))
120
- # return response
121
- def handle_query(query):
122
- chat_text_qa_msgs = [
123
- (
124
- "user",
125
- """
126
- You are the Hotel voice chatbot and your name is hotel helper. Your goal is to provide accurate, professional, and helpful answers to user queries based on the hotel's data. Always ensure your responses are clear and concise. Give response within 10-15 words only. You need to give an answer in the same language used by the user.
127
- {context_str}
128
- Question:
129
- {query_str}
130
- """
131
- )
132
- ]
133
- text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
134
-
135
- storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
136
- index = load_index_from_storage(storage_context)
137
- context_str = ""
138
- for past_query, response in reversed(current_chat_history):
139
- if past_query.strip():
140
- context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
141
-
142
- query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
143
- print(query)
144
- answer = query_engine.query(query)
145
-
146
- if hasattr(answer, 'response'):
147
- response = answer.response
148
- elif isinstance(answer, dict) and 'response' in answer:
149
- response = answer['response']
150
- else:
151
- response = "Sorry, I couldn't find an answer."
152
- current_chat_history.append((query, response))
153
- return response
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  app = Flask(__name__)
156
 
157
  # Data ingestion
158
  data_ingestion_from_directory()
159
 
160
- # Generate Response
161
  def generate_response(query, language):
162
  try:
163
  # Call the handle_query function to get the response
164
- bot_response = handle_query(query)
165
-
166
- # Map of supported languages
167
- supported_languages = {
168
- "hindi": "hi",
169
- "bengali": "bn",
170
- "telugu": "te",
171
- "marathi": "mr",
172
- "tamil": "ta",
173
- "gujarati": "gu",
174
- "kannada": "kn",
175
- "malayalam": "ml",
176
- "punjabi": "pa",
177
- "odia": "or",
178
- "urdu": "ur",
179
- "assamese": "as",
180
- "sanskrit": "sa",
181
- "arabic": "ar",
182
- "australian": "en-AU",
183
- "bangla-india": "bn-IN",
184
- "chinese": "zh-CN",
185
- "dutch": "nl",
186
- "french": "fr",
187
- "filipino": "tl",
188
- "greek": "el",
189
- "indonesian": "id",
190
- "italian": "it",
191
- "japanese": "ja",
192
- "korean": "ko",
193
- "latin": "la",
194
- "nepali": "ne",
195
- "portuguese": "pt",
196
- "romanian": "ro",
197
- "russian": "ru",
198
- "spanish": "es",
199
- "swedish": "sv",
200
- "thai": "th",
201
- "ukrainian": "uk",
202
- "turkish": "tr"
203
- }
204
-
205
- # Initialize the translated text
206
- translated_text = bot_response
207
-
208
- # Translate only if the language is supported and not English
209
- try:
210
- if language in supported_languages:
211
- target_lang = supported_languages[language]
212
- translated_text = GoogleTranslator(source='en', target=target_lang).translate(bot_response)
213
- print(translated_text)
214
- else:
215
- print(f"Unsupported language: {language}")
216
- except Exception as e:
217
- # Handle translation errors
218
- print(f"Translation error: {e}")
219
- translated_text = "Sorry, I couldn't translate the response."
220
-
221
- # Append to chat history
222
- chat_history.append((query, translated_text))
223
  return translated_text
224
  except Exception as e:
225
- return f"Error fetching the response: {str(e)}"
226
 
227
  # Route for the homepage
228
  @app.route('/')
@@ -233,12 +165,12 @@ def index():
233
  @app.route('/chat', methods=['POST'])
234
  def chat():
235
  try:
236
- user_message = request.json.get("message")
237
- language = request.json.get("language")
238
  if not user_message:
239
  return jsonify({"response": "Please say something!"})
240
 
241
- bot_response = generate_response(user_message,language)
242
  return jsonify({"response": bot_response})
243
  except Exception as e:
244
  return jsonify({"response": f"An error occurred: {str(e)}"})
 
1
  import os
2
  import shutil
3
+ import torch
4
  from flask import Flask, render_template, request, jsonify
5
+ from whoosh.index import create_in
6
+ from whoosh.fields import Schema, TEXT
7
+ from whoosh.qparser import QueryParser
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from deep_translator import GoogleTranslator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
# Ensure the necessary directories exist
# PERSIST_DIR holds the on-disk Whoosh search index; PDF_DIRECTORY holds the
# source documents to be ingested (only .txt files are actually read).
PERSIST_DIR = "db"
PDF_DIRECTORY = 'data'
# exist_ok=True makes these idempotent across restarts.
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)
 
 
16
 
17
# Load the XLM-R tokenizer and model (used by get_embeddings; downloaded
# from the Hugging Face hub on first run).
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModel.from_pretrained("xlm-roberta-base")

# Setup Whoosh schema for indexing: both fields are stored so search
# results can return the original title and content verbatim.
schema = Schema(title=TEXT(stored=True), content=TEXT(stored=True))

# Create an index in the persist directory.
# makedirs(exist_ok=True) replaces the original exists()/mkdir() pair,
# which was redundant (the directory is already created above) and
# race-prone between the check and the mkdir call.
os.makedirs(PERSIST_DIR, exist_ok=True)
# NOTE(review): create_in() starts a brand-new empty index, discarding any
# previously persisted one — data_ingestion_from_directory() repopulates it.
index = create_in(PERSIST_DIR, schema)
28
+
29
# Function to load documents from a directory
def load_documents(directory=None):
    """Load plain-text documents from *directory*.

    Args:
        directory: Folder to scan; defaults to the module-level
            PDF_DIRECTORY. Despite that constant's name, no PDF parsing
            happens here — only ``.txt`` files are read (UTF-8).

    Returns:
        list[dict]: One ``{'title': filename, 'content': text}`` per file.
    """
    if directory is None:
        directory = PDF_DIRECTORY
    documents = []
    # sorted() makes ingestion order deterministic; os.listdir order is
    # filesystem-dependent and unspecified.
    for filename in sorted(os.listdir(directory)):
        if filename.endswith(".txt"):
            with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
                documents.append({'title': filename, 'content': file.read()})
    return documents
38
+
39
# Function to index documents
def index_documents(documents):
    """Add *documents* to the module-level Whoosh index.

    Args:
        documents: Iterable of mappings with 'title' and 'content' keys.

    Using the writer as a context manager commits on success and cancels
    the write (releasing the index lock) if an exception is raised; the
    original bare writer()/commit() pattern left the lock held on failure,
    blocking all later writes.
    """
    with index.writer() as writer:
        for doc in documents:
            writer.add_document(title=doc['title'], content=doc['content'])
45
+
46
# Data ingestion from the directory
def data_ingestion_from_directory():
    """Rebuild the search index from the files in PDF_DIRECTORY.

    Wipes any previously persisted index data, recreates the storage
    directory and the Whoosh index, then indexes every .txt document.
    """
    global index
    # Clear previous data by removing the persist directory
    if os.path.exists(PERSIST_DIR):
        shutil.rmtree(PERSIST_DIR)  # Remove the persist directory and all its contents

    # Recreate the persist directory after removal
    os.makedirs(PERSIST_DIR, exist_ok=True)

    # BUG FIX: rmtree() above deleted the files backing the module-level
    # `index` object created at import time, so the index must be
    # recreated and rebound here; otherwise index.writer() / searches
    # fail on the now-missing segment files.
    index = create_in(PERSIST_DIR, schema)

    # Load new documents from the directory
    new_documents = load_documents()

    # Index the new documents
    index_documents(new_documents)
60
+
61
# Function to retrieve documents based on a query
def retrieve_documents(query):
    """Search the Whoosh index for *query*.

    Returns:
        list[tuple]: (title, content) pairs for every matching document.
        The list is fully materialised before the searcher closes, so
        the results remain usable after this function returns.
    """
    parser = QueryParser("content", index.schema)
    parsed_query = parser.parse(query)
    with index.searcher() as searcher:
        hits = searcher.search(parsed_query)
        return [(hit['title'], hit['content']) for hit in hits]
68
+
69
# Function to generate embeddings
def get_embeddings(text):
    """Embed *text* with XLM-R and return a 1-D numpy vector.

    Mean-pools the last hidden state over the sequence (token) axis.
    NOTE(review): not called anywhere in the visible code path — kept,
    presumably, for semantic-search experiments; confirm before removing.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    return hidden.mean(dim=1).squeeze().numpy()  # Average pooling
76
+
77
# Function to handle queries and generate responses
def handle_query(query):
    """Answer *query* from the indexed documents.

    Returns a fixed apology string when nothing matches; otherwise a
    summary listing each hit's title and the first 100 characters of
    its content.
    """
    matches = retrieve_documents(query)

    if not matches:
        return "Sorry, I couldn't find an answer."

    # Construct a response using the retrieved documents
    snippets = [
        f"Title: {title}\nContent: {content[:100]}..."
        for title, content in matches
    ]
    return "Here are some insights based on your query:\n" + "\n".join(snippets)
89
+
90
# Initialize Flask app
app = Flask(__name__)

# Data ingestion
# NOTE(review): this runs at import time, so every process start re-reads
# and re-indexes the whole data directory before the server can respond.
data_ingestion_from_directory()
95
 
96
# Generate Response
# Map of supported languages: human-readable name -> translation target
# code. Built once at import time instead of on every request.
# NOTE(review): region-qualified codes ("en-AU", "bn-IN", "zh-CN") may not
# be accepted by deep_translator's GoogleTranslator — verify against its
# supported-language list.
SUPPORTED_LANGUAGES = {
    "hindi": "hi",
    "bengali": "bn",
    "telugu": "te",
    "marathi": "mr",
    "tamil": "ta",
    "gujarati": "gu",
    "kannada": "kn",
    "malayalam": "ml",
    "punjabi": "pa",
    "odia": "or",
    "urdu": "ur",
    "assamese": "as",
    "sanskrit": "sa",
    "arabic": "ar",
    "australian": "en-AU",
    "bangla-india": "bn-IN",
    "chinese": "zh-CN",
    "dutch": "nl",
    "french": "fr",
    "filipino": "tl",
    "greek": "el",
    "indonesian": "id",
    "italian": "it",
    "japanese": "ja",
    "korean": "ko",
    "latin": "la",
    "nepali": "ne",
    "portuguese": "pt",
    "romanian": "ro",
    "russian": "ru",
    "spanish": "es",
    "swedish": "sv",
    "thai": "th",
    "ukrainian": "uk",
    "turkish": "tr"
}

def generate_response(query, language):
    """Answer *query* and translate the reply into *language*.

    Args:
        query: The user's question, passed to handle_query().
        language: Human-readable language name (e.g. "french"); matched
            case-insensitively against SUPPORTED_LANGUAGES. Unknown or
            missing languages leave the response untranslated.

    Returns:
        str: The (possibly translated) bot response, or an error string —
        this function never raises.
    """
    try:
        # Call the handle_query function to get the response
        bot_response = handle_query(query)

        # Initialize the translated text (fallback: untranslated)
        translated_text = bot_response

        # Normalize so "French", " FRENCH " etc. also match; None -> "".
        lang_key = (language or "").strip().lower()

        # Translate only if the language is supported and not English
        try:
            if lang_key in SUPPORTED_LANGUAGES:
                target_lang = SUPPORTED_LANGUAGES[lang_key]
                # source='auto' lets the service detect the reply language.
                translated_text = GoogleTranslator(source='auto', target=target_lang).translate(bot_response)
            else:
                print(f"Unsupported language: {language}")
        except Exception as e:
            # Translation is best-effort: report and fall back gracefully.
            print(f"Translation error: {e}")
            translated_text = "Sorry, I couldn't translate the response."

        return translated_text
    except Exception as e:
        return f"Error fetching the response: {str(e)}"
158
 
159
  # Route for the homepage
160
  @app.route('/')
 
165
@app.route('/chat', methods=['POST'])
def chat():
    """POST /chat — body {"message": ..., "language": ...} -> {"response": ...}."""
    try:
        # get_json(silent=True) returns None instead of raising when the
        # body is missing or not valid JSON; bare request.json would fail
        # there and surface as a generic error to the client.
        payload = request.get_json(silent=True) or {}
        user_message = payload.get("message")
        language = payload.get("language")
        if not user_message:
            return jsonify({"response": "Please say something!"})

        bot_response = generate_response(user_message, language)
        return jsonify({"response": bot_response})
    except Exception as e:
        return jsonify({"response": f"An error occurred: {str(e)}"})