Shroogawh24 committed
Commit cbb53fd
Parent: 9b5e34e

Create app.py

Files changed (1): app.py (+113 -0)
app.py ADDED

import gradio as gr
import os
import numpy as np
import openai
import pandas as pd
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.chains import LLMChain
from langchain.schema.output_parser import StrOutputParser
from langchain.chat_models import ChatOpenAI

# Set up the Hugging Face model and embeddings
model_name = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': True}

embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

# Set the OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")

# Load the FAISS index using LangChain's FAISS implementation
db = FAISS.load_local("Faiss", embedding_function, allow_dangerous_deserialization=True)
parser = StrOutputParser()  # currently unused

# Load your data (e.g., a DataFrame)
df = pd.read_pickle('df_news.pkl')

# Mapping from chunk position to source row and chunk text, used by search().
# NOTE: this was undefined in the original file; the pickle filename is an assumption.
metadata_info = pd.read_pickle('metadata_info.pkl')

# Search function to retrieve relevant documents
def search(query):
    # embed_query returns a plain list; convert it to the (1, dim) float32
    # array that the raw FAISS index expects
    query_embedding = np.array(embedding_function.embed_query(query), dtype='float32').reshape(1, -1)
    # Query the underlying FAISS index directly to get raw vector IDs
    # (similarity_search_with_score takes a text query and returns
    # Documents, not index positions)
    D, I = db.index.search(query_embedding, k=10)
    results = []
    for idx in I[0]:
        if idx < 3327:  # vector IDs below 3327 map one-to-one to rows of df; adjust this based on your indexing
            doc_index = idx
            results.append({
                'type': 'metadata',
                'title': df.iloc[doc_index]['title'],
                'author': df.iloc[doc_index]['author'],
                'full_text': df.iloc[doc_index]['full_text'],
                'source': df.iloc[doc_index]['url']
            })
        else:
            chunk_index = idx - 3327
            metadata = metadata_info[chunk_index]
            doc_index = metadata['index']
            chunk_text = metadata['chunk']
            results.append({
                'type': 'content',
                'title': df.iloc[doc_index]['title'],
                'author': df.iloc[doc_index]['author'],
                'content': chunk_text,
                'source': df.iloc[doc_index]['url']
            })

    return results

# Generate an answer based on the retrieved documents
def generate_answer(query):
    context = search(query)
    context_str = "\n\n".join(
        f"Title: {doc['title']}\nContent: {doc.get('content', doc.get('full_text', ''))}"
        for doc in context
    )

    # Set up the ChatOpenAI model with temperature and other parameters
    chat = ChatOpenAI(
        model="gpt-4",
        temperature=0.2,
        max_tokens=1500,
        openai_api_key=openai.api_key
    )

    # Keep the context and question as template variables so that braces
    # inside the retrieved text are not misread as placeholders
    messages = [
        SystemMessagePromptTemplate.from_template("You are a helpful assistant."),
        HumanMessagePromptTemplate.from_template(
            'Answer the question based on the context below. '
            'If you can\'t answer the question, answer with "I don\'t know".\n'
            "Context: {context}\n"
            "Question: {question}"
        )
    ]

    chat_chain = LLMChain(
        llm=chat,
        prompt=ChatPromptTemplate.from_messages(messages)
    )

    # Get the response from the chat model; run() fills in the template
    # variables and returns the model's text output
    response = chat_chain.run(context=context_str, question=query)
    return response.strip()

# Gradio chat interface callback
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Note: the extra UI controls are accepted here, but generate_answer
    # currently uses its own fixed model settings
    response = generate_answer(message)
    yield response

# Gradio demo setup
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
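
The idx < 3327 split in search() assumes the FAISS index stores one vector per row of df_news.pkl (3,327 articles) first, followed by one vector per text chunk, with metadata_info recording each chunk's source row and text. That index is not built in this file; the sketch below shows one way it could have been constructed so the raw FAISS IDs line up with that layout. The filenames, the splitter, and its chunk sizes are assumptions, not the author's actual build script.

import pandas as pd
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

df = pd.read_pickle('df_news.pkl')  # expected to have 3327 rows
embedding_function = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en-v1.5")

# Chunker and sizes are assumptions for illustration
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# Insert one text per article first, then one text per chunk, so raw
# FAISS IDs below len(df) map to article rows and the rest map to chunks
texts = df['full_text'].tolist()
metadata_info = []
for row, full_text in enumerate(df['full_text']):
    for chunk in splitter.split_text(full_text):
        texts.append(chunk)
        metadata_info.append({'index': row, 'chunk': chunk})

db = FAISS.from_texts(texts, embedding_function)
db.save_local("Faiss")
pd.to_pickle(metadata_info, 'metadata_info.pkl')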