Gopikanth123 committed on
Commit
7e7428c
·
verified ·
1 Parent(s): 34cc5b3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +144 -38
main.py CHANGED
@@ -1,77 +1,183 @@
1
  import os
2
- from flask import Flask, request, jsonify
3
- from llama_index import SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage, ChatPromptTemplate
 
 
 
4
  from huggingface_hub import InferenceClient
5
- from transformers import AutoTokenizer, AutoModel
6
  from deep_translator import GoogleTranslator
 
7
 
8
  # Ensure HF_TOKEN is set
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
  if not HF_TOKEN:
11
  raise ValueError("HF_TOKEN environment variable not set.")
12
 
13
- # Hugging Face model configuration
14
- REPO_ID = "facebook/xlm-roberta-xl" # Use xlm-roberta-xl model
15
- tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
16
- model = AutoModel.from_pretrained(REPO_ID)
17
 
18
- # Flask app
19
- app = Flask(__name__)
20
-
21
- # Directories for storing data
22
  PERSIST_DIR = "db"
23
- PDF_DIRECTORY = "data"
 
 
24
  os.makedirs(PDF_DIRECTORY, exist_ok=True)
25
  os.makedirs(PERSIST_DIR, exist_ok=True)
26
-
27
- # Initialize variables
28
  chat_history = []
 
29
 
30
- # Function to ingest documents
31
  def data_ingestion_from_directory():
 
32
  if os.path.exists(PERSIST_DIR):
33
- os.system(f"rm -rf {PERSIST_DIR}") # Clear previous data
 
 
34
  os.makedirs(PERSIST_DIR, exist_ok=True)
35
 
36
- documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
37
- index = VectorStoreIndex.from_documents(documents)
 
 
 
 
 
38
  index.storage_context.persist(persist_dir=PERSIST_DIR)
39
 
40
- # Function to handle queries
41
  def handle_query(query):
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
43
  index = load_index_from_storage(storage_context)
44
- query_engine = index.as_query_engine()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- chat_prompt = ChatPromptTemplate.from_messages([
47
- ("user", "User asked: {query_str}"),
48
- ("assistant", "Answer: {response}"),
49
- ])
50
 
51
- result = query_engine.query(query, prompt_template=chat_prompt)
52
- return result.response if hasattr(result, 'response') else "No relevant answer found."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # Route for homepage
55
- @app.route("/")
56
  def index():
57
- return "Welcome to the RAG Application using xlm-roberta-xl!"
58
 
59
  # Route to handle chatbot messages
60
- @app.route("/chat", methods=["POST"])
61
  def chat():
62
  try:
63
  user_message = request.json.get("message")
 
64
  if not user_message:
65
- return jsonify({"response": "Please provide a message!"})
66
 
67
- # Generate response
68
- response = handle_query(user_message)
69
- chat_history.append({"user": user_message, "bot": response})
70
- return jsonify({"response": response})
71
  except Exception as e:
72
  return jsonify({"response": f"An error occurred: {str(e)}"})
73
 
74
- if __name__ == "__main__":
75
- # Ingest data before starting the app
76
- data_ingestion_from_directory()
77
  app.run(debug=True)
 
1
  import os
2
+ import shutil
3
+ from flask import Flask, render_template, request, jsonify
4
+ from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
5
+ from llama_index.llms.huggingface import HuggingFaceInferenceAPI
6
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
7
  from huggingface_hub import InferenceClient
8
+ from transformers import AutoTokenizer, AutoModel, XLMRobertaXLForMultipleChoice
9
  from deep_translator import GoogleTranslator
10
+ import torch
11
 
12
  # Ensure HF_TOKEN is set
13
  HF_TOKEN = os.getenv("HF_TOKEN")
14
  if not HF_TOKEN:
15
  raise ValueError("HF_TOKEN environment variable not set.")
16
 
17
+ repo_id = "facebook/xlm-roberta-xl"
18
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
19
+ model = XLMRobertaXLForMultipleChoice.from_pretrained(repo_id)
 
20
 
 
 
 
 
21
  PERSIST_DIR = "db"
22
+ PDF_DIRECTORY = 'data'
23
+
24
+ # Ensure directories exist
25
  os.makedirs(PDF_DIRECTORY, exist_ok=True)
26
  os.makedirs(PERSIST_DIR, exist_ok=True)
 
 
27
  chat_history = []
28
+ current_chat_history = []
29
 
 
30
  def data_ingestion_from_directory():
31
+ # Clear previous data by removing the persist directory
32
  if os.path.exists(PERSIST_DIR):
33
+ shutil.rmtree(PERSIST_DIR) # Remove the persist directory and all its contents
34
+
35
+ # Recreate the persist directory after removal
36
  os.makedirs(PERSIST_DIR, exist_ok=True)
37
 
38
+ # Load new documents from the directory
39
+ new_documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
40
+
41
+ # Create a new index with the new documents
42
+ index = VectorStoreIndex.from_documents(new_documents)
43
+
44
+ # Persist the new index
45
  index.storage_context.persist(persist_dir=PERSIST_DIR)
46
 
 
47
  def handle_query(query):
48
+ chat_text_qa_msgs = [
49
+ (
50
+ "user",
51
+ """
52
+ You are the Hotel voice chatbot and your name is hotel helper. Your goal is to provide accurate, professional, and helpful answers to user queries based on the hotel's data. Always ensure your responses are clear and concise. Give response within 10-15 words only. You need to give an answer in the same language used by the user.
53
+ {context_str}
54
+ Question:
55
+ {query_str}
56
+ """
57
+ )
58
+ ]
59
+ text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
60
+
61
  storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
62
  index = load_index_from_storage(storage_context)
63
+ context_str = ""
64
+ for past_query, response in reversed(current_chat_history):
65
+ if past_query.strip():
66
+ context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
67
+
68
+ query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
69
+ print(query)
70
+ answer = query_engine.query(query)
71
+
72
+ if hasattr(answer, 'response'):
73
+ response = answer.response
74
+ elif isinstance(answer, dict) and 'response' in answer:
75
+ response = answer['response']
76
+ else:
77
+ response = "Sorry, I couldn't find an answer."
78
+ current_chat_history.append((query, response))
79
+ return response
80
 
81
def evaluate_model(prompt, choice0, choice1):
    """Score two candidate answers for *prompt* with the multiple-choice model.

    Returns a (loss, logits) pair; choice 0 is treated as the correct label.
    """
    # Batch of one example whose correct answer is choice 0.
    target = torch.tensor(0).unsqueeze(0)

    # Encode the prompt against both candidate choices.
    enc = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)

    # Add the batch dimension the model expects, then run the forward pass.
    batched = {key: tensor.unsqueeze(0) for key, tensor in enc.items()}
    result = model(**batched, labels=target)

    # NOTE: the classification head is still untrained, so these values are
    # not meaningful until the head is fine-tuned.
    return result.loss, result.logits
90
+
91
# Flask application instance (serves the chat UI and the /chat API).
app = Flask(__name__)

# Data ingestion
# Build the vector index from the PDF directory once, at startup.
data_ingestion_from_directory()
95
+
96
+ # Generate Response
97
# Language-name -> Google-Translate language-code map.  Hoisted to a module
# constant so the dict is built once instead of on every request.
SUPPORTED_LANGUAGES = {
    "hindi": "hi",
    "bengali": "bn",
    "telugu": "te",
    "marathi": "mr",
    "tamil": "ta",
    "gujarati": "gu",
    "kannada": "kn",
    "malayalam": "ml",
    "punjabi": "pa",
    "odia": "or",
    "urdu": "ur",
    "assamese": "as",
    "sanskrit": "sa",
    "arabic": "ar",
    "australian": "en-AU",
    "bangla-india": "bn-IN",
    "chinese": "zh-CN",
    "dutch": "nl",
    "french": "fr",
    "filipino": "tl",
    "greek": "el",
    "indonesian": "id",
    "italian": "it",
    "japanese": "ja",
    "korean": "ko",
    "latin": "la",
    "nepali": "ne",
    "portuguese": "pt",
    "romanian": "ro",
    "russian": "ru",
    "spanish": "es",
    "swedish": "sv",
    "thai": "th",
    "ukrainian": "uk",
    "turkish": "tr"
}


# Generate Response
def generate_response(query, language):
    """Produce the bot reply for *query*, translated into *language*.

    Gets the English answer from handle_query(), translates it when the
    requested language is supported, records the (query, reply) pair in
    chat_history, and returns the (possibly translated) reply.  On failure
    returns an error string rather than raising.
    """
    try:
        # English answer from the RAG pipeline.
        bot_response = handle_query(query)

        # Default: return the untranslated answer.
        translated_text = bot_response

        # Translate only if the language is supported and not English.
        try:
            # Tolerate None, stray whitespace, and mixed case in the
            # client-supplied language name (previously an exact-match lookup).
            lang_key = (language or "").strip().lower()
            if lang_key in SUPPORTED_LANGUAGES:
                target_lang = SUPPORTED_LANGUAGES[lang_key]
                translated_text = GoogleTranslator(source='en', target=target_lang).translate(bot_response)
                print(translated_text)
            else:
                print(f"Unsupported language: {language}")
        except Exception as e:
            # Translation failure must not lose the request; report it instead.
            print(f"Translation error: {e}")
            translated_text = "Sorry, I couldn't translate the response."

        # Append to chat history
        chat_history.append((query, translated_text))
        return translated_text
    except Exception as e:
        return f"Error fetching the response: {str(e)}"
162
 
163
+ # Route for the homepage
164
@app.route('/')
def index():
    """Serve the chatbot front-end page."""
    page = render_template('index.html')
    return page
167
 
168
  # Route to handle chatbot messages
169
@app.route('/chat', methods=['POST'])
def chat():
    """POST endpoint: turn a user message into a (translated) bot reply.

    Expects JSON with "message" and "language" keys; always answers with a
    JSON object carrying a single "response" field.
    """
    try:
        payload = request.json
        user_message = payload.get("message")
        language = payload.get("language")

        # Guard clause: nothing to answer.
        if not user_message:
            return jsonify({"response": "Please say something!"})

        return jsonify({"response": generate_response(user_message, language)})
    except Exception as e:
        # Surface the failure to the client instead of a bare 500.
        return jsonify({"response": f"An error occurred: {str(e)}"})
181
 
182
if __name__ == '__main__':
    # Launch the Flask development server.
    # NOTE(review): debug=True enables the interactive debugger — do not
    # deploy this entry point to production as-is.
    app.run(debug=True)