Spaces:
Sleeping
Sleeping
import eventlet | |
eventlet.monkey_patch() | |
from dotenv import load_dotenv | |
from flask import Flask, request, render_template | |
from flask_cors import CORS | |
from flask_socketio import SocketIO, emit | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_huggingface.embeddings import HuggingFaceEmbeddings | |
from langchain.retrievers.document_compressors import EmbeddingsFilter | |
from langchain.retrievers import ContextualCompressionRetriever | |
from langchain.retrievers import EnsembleRetriever | |
from langchain_community.vectorstores import FAISS | |
from langchain_groq import ChatGroq | |
from langchain import hub | |
import pickle | |
import os | |
# Load environment variables from the local .env file, then read the settings
# this app needs.
load_dotenv(".env")
USER_AGENT = os.getenv("USER_AGENT")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
SECRET_KEY = os.getenv("SECRET_KEY")
# NOTE(review): not referenced anywhere else in this file.
SESSION_ID_DEFAULT = "abc123"

# Re-export settings to the process environment for libraries that read them
# from there. os.environ only accepts str values, so assigning None (an unset
# variable) would crash with an opaque TypeError. The Groq key is required
# because ChatGroq below is constructed without an explicit key, so fail fast
# with a clear message. USER_AGENT is optional and is only exported when present.
if GROQ_API_KEY is None:
    raise RuntimeError("GROQ_API_KEY is not set; add it to .env or the process environment.")
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
if USER_AGENT is not None:
    os.environ["USER_AGENT"] = USER_AGENT
# Explicitly enable Hugging Face tokenizers parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = 'true'
# Create the Flask application, enable CORS on its HTTP routes, and attach a
# Socket.IO server that accepts connections from any origin.
# NOTE(review): both CORS(app) with defaults and cors_allowed_origins="*" accept
# every origin; restrict them if the front-end origin is known.
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")
app.config.update(
    SESSION_COOKIE_SECURE=True,    # only send the session cookie over HTTPS
    SESSION_COOKIE_HTTPONLY=True,  # hide the session cookie from client-side JS
    SECRET_KEY=SECRET_KEY,         # key used to sign the session cookie
)
# Embedding model used by both FAISS stores and by the embeddings filter.
# trust_remote_code=True allows Python code shipped in the model repository to
# run on load. NOTE(review): consider pinning the model revision.
embed_model = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-multilingual-base",
    model_kwargs={"trust_remote_code": True},
)

# Groq-hosted chat model. temperature=0.0 minimises sampling randomness, and
# answers are capped at 1024 tokens.
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0.0, max_tokens=1024, max_retries=2)
# Load the two prebuilt FAISS indexes from disk.
# NOTE(review): allow_dangerous_deserialization=True is needed because
# load_local unpickles stored metadata. Only load index folders this project
# produced itself.
excel_vectorstore = FAISS.load_local(
    folder_path="./faiss_excel_doc_index",
    embeddings=embed_model,
    allow_dangerous_deserialization=True,
)
word_vectorstore = FAISS.load_local(
    folder_path="./faiss_recursive_split_word_doc_index",
    embeddings=embed_model,
    allow_dangerous_deserialization=True,
)

# Fold the Word-document vectors into the Excel store. The merged store is the
# same object as excel_vectorstore; combined_vectorstore is just a clearer alias.
excel_vectorstore.merge_from(word_vectorstore)
combined_vectorstore = excel_vectorstore
# Restore the keyword retriever that was serialised with pickle.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load a pickle this project produced.
with open("combined_recursive_keyword_retriever.pkl", "rb") as f:
    combined_keyword_retriever = pickle.load(f)
    # Ask the retriever for up to 1000 results per query. This retriever only
    # feeds ensemble_retriever, which rag_chain does not use in this file.
    combined_keyword_retriever.k = 1000
# Max-marginal-relevance search over the merged FAISS store with k=100.
# NOTE(review): LangChain's FAISS MMR search defaults fetch_k to 20 and selects
# k results from those candidates, so this likely returns at most 20 documents.
# Set "fetch_k" explicitly if 100 is really wanted.
semantic_retriever = combined_vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 100},
)

# Equal-weight combination of keyword and semantic results.
# NOTE(review): ensemble_retriever is not referenced by rag_chain; the
# compression retriever below wraps semantic_retriever only. Confirm whether
# the ensemble was meant to be the base retriever.
ensemble_retriever = EnsembleRetriever(
    retrievers=[combined_keyword_retriever, semantic_retriever],
    weights=[0.5, 0.5],
)

# Post-filter retrieved documents by embedding similarity to the query
# (threshold 0.4), applied on top of the semantic retriever.
embeddings_filter = EmbeddingsFilter(embeddings=embed_model, similarity_threshold=0.4)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=semantic_retriever,
)
template = """ | |
User Instructions: | |
You are an Arabic AI Assistant focused on providing clear, concise responses. | |
Always answer truthfully. If the user query is irrelevant to the provided CONTEXT, respond stating the reason. | |
Generate responses in Arabic. Format any English words and numbers appropriately for clarity. | |
Round off numbers with decimal integers to two decimal integers. | |
Use bullet points or numbered lists where applicable for better organization. | |
Provide detailed yet concise answers, covering all important aspects. | |
Remember, responding outside the CONTEXT may lead to the termination of the interaction. | |
CONTEXT: {context} | |
Query: {question} | |
After generating your response, ensure proper formatting and text direction of Arabic and English words/numbers. Return only the AI-generated answer. | |
""" | |
# Chat prompt built from the template above, plus a parser that turns the
# model's reply message into a plain string.
prompt, output_parser = ChatPromptTemplate.from_template(template), StrOutputParser()
def format_docs(docs):
    """Concatenate the ``page_content`` of each document, separated by a blank line.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The joined text, or an empty string when *docs* is empty.
    """
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)
# Retrieval-augmented pipeline. The incoming question string goes both to the
# compression retriever (its documents are flattened by format_docs into
# {context}) and, unchanged, into {question}. The filled prompt is then sent to
# the LLM, and the reply is parsed to plain text.
rag_chain = (
    {
        "context": compression_retriever.with_config(run_name="Docs") | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | output_parser
)
# Function to handle WebSocket connection.
# The handler was previously never registered with Socket.IO; 'connect' is the
# reserved Socket.IO connection event.
@socketio.on('connect')
def handle_connect():
    """Send a 'connection_response' event to the newly connected client only."""
    emit('connection_response', {'message': 'Connected successfully.'}, room=request.sid)
# Health-check handler. It was previously never registered with Socket.IO.
# NOTE(review): the event name 'ping' is inferred from the function name;
# confirm it matches what the client emits in chat.html.
@socketio.on('ping')
def handle_ping(data):
    """Reply to the sender with a 'ping_response' event; *data* is ignored."""
    emit('ping_response', {'message': 'Healthy Connection.'}, room=request.sid)
# Function to handle WebSocket disconnection.
# The handler was previously never registered with Socket.IO; 'disconnect' is
# the reserved Socket.IO disconnection event.
@socketio.on('disconnect')
def handle_disconnect():
    """Emit a 'connection_response' event when a client disconnects.

    NOTE(review): the target client is already leaving, so this message is
    probably never delivered. Consider logging instead.
    """
    emit('connection_response', {'message': 'Disconnected successfully.'})
# Function to handle WebSocket messages.
# The handler was previously never registered with Socket.IO.
# NOTE(review): the event name 'message' is inferred from the function name;
# confirm it matches what the client emits in chat.html.
@socketio.on('message')
def handle_message(data):
    """Stream the RAG answer for ``data['question']`` back to the sender.

    Each text chunk from ``rag_chain.stream`` is emitted to the sender as its
    own 'response' event. An invalid payload, or a failure inside the pipeline,
    is reported as a 'response' event carrying an ``{"error": ...}`` dict.
    """
    # Presumably the client sends an object like {"question": "..."}. Any other
    # payload shape previously raised AttributeError outside the try block.
    question = data.get('question') if isinstance(data, dict) else None
    if not isinstance(question, str) or not question.strip():
        emit('response', {"error": "A non-empty 'question' field is required."}, room=request.sid)
        return
    try:
        for chunk in rag_chain.stream(question):
            emit('response', chunk, room=request.sid)
    except Exception:
        # Top-level boundary: record the full traceback server-side, and send
        # the client only a generic message.
        app.logger.exception("RAG pipeline failed for sid %s", request.sid)
        emit('response', {"error": "An error occurred while processing your request."}, room=request.sid)
# Home route. It was previously never registered with the Flask app.
@app.route('/')
def index_view():
    """Render and return the chat page template 'chat.html'."""
    return render_template('chat.html')
# Main function to run the app
if __name__ == '__main__':
    # Debug mode is now opt-in (FLASK_DEBUG=1/true/yes) rather than hard-coded.
    # It turns on the auto-reloader, which re-imports this module and so loads
    # the embedding model and FAISS indexes again. It also enables debug error
    # pages. Neither belongs in a deployed instance by default.
    debug_mode = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    socketio.run(app, debug=debug_mode)