hatim00101 commited on
Commit
7378948
1 Parent(s): 5b49563

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ generator = pipeline("text-generation", model="tiiuae/falcon-rw-1b")
3
+ from sentence_transformers import SentenceTransformer , CrossEncoder
4
+ from transformers import pipeline
5
+ embedder = SentenceTransformer('all-MiniLM-L6-v2')
6
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
7
+ import numpy as np
8
+ import faiss
9
+ import pickle
10
+ from rank_bm25 import BM25Okapi
11
+ import gradio as gr
12
+
13
+ # Load embeddings and FAISS index
14
+ with open("Assemesment5_day4.my_faiss_embeddings.pkl", "rb") as f:
15
+ embeddings = pickle.load(f)
16
+
17
+ faiss_index = faiss.read_index("my_faiss_index.faiss")
18
+
19
+ # Load chunks
20
+ with open('chunks.pkl', 'rb') as f:
21
+ chunks = pickle.load(f)
22
+
23
+ bm25 = BM25Okapi([chunk['text'].split() for chunk in chunks])
24
+
25
+
26
+ def hybrid_search(query, top_k=5):
27
+ query_tokens = query.split()
28
+
29
+ # BM25 retrieval
30
+ bm25_scores = bm25.get_scores(query_tokens)
31
+ top_bm25_indices = np.argsort(bm25_scores)[::-1][:top_k]
32
+
33
+ # FAISS retrieval
34
+ query_embedding = embedder.encode([query])
35
+ distances, faiss_indices = faiss_index.search(query_embedding, top_k)
36
+
37
+ # Combine results
38
+ combined_indices = np.unique(np.concatenate((top_bm25_indices, faiss_indices[0])), axis=0)[:top_k]
39
+
40
+ combined_chunks = [chunks[i] for i in combined_indices]
41
+ inputs = [(query, chunk['text']) for chunk in combined_chunks]
42
+
43
+ # Cross-encoder reranking
44
+ scores = cross_encoder.predict(inputs)
45
+ reranked_chunks = [chunk for _, chunk in sorted(zip(scores, combined_chunks), reverse=True)]
46
+
47
+ return reranked_chunks
48
+
49
+ def two_stage_rag_search(query, top_k=5):
50
+ results = hybrid_search(query, top_k)
51
+
52
+ context = "\n\n".join([chunk['text'] for chunk in results])
53
+
54
+ extraction_prompt = (
55
+ f"Given the following context, extract the most relevant passage that answers the question.\n\n"
56
+ f"Context:\n{context}\n\n"
57
+ f"Question: {query}\n\n"
58
+ f"Relevant Passage:\n"
59
+ )
60
+
61
+ extraction_response = generator(extraction_prompt, max_length=1000, num_return_sequences=1)
62
+ relevant_passage = extraction_response[0]['generated_text'].strip()
63
+
64
+ answer_prompt = (
65
+ f"Based on the passage below, generate a detailed and thoughtful answer to the question.\n\n"
66
+ f"Relevant Passage: {relevant_passage}\n\n"
67
+ f"Question: {query}\n\n"
68
+ f"Answer:\n"
69
+ f"Format your response as follows:\n"
70
+ f"Metadata:\n"
71
+ f"Author: 'author'\n"
72
+ f"Title: 'title'\n"
73
+ f"Date: 'date'\n"
74
+ f"Description: 'description'\n\n"
75
+ f"Content or text:\n"
76
+ f"{relevant_passage}"
77
+ )
78
+
79
+ answer_response = generator(answer_prompt, max_length=1500, num_return_sequences=1)
80
+ final_answer = answer_response[0]['generated_text'].strip()
81
+
82
+ return final_answer
83
+
84
+ def gradio_interface(query, feedback):
85
+ results = hybrid_search(query, top_k=5)
86
+
87
+ # Convert results to a format suitable for Gradio
88
+ result_texts = "\n\n".join([f"Text: {chunk['text']}\nMetadata: {chunk['metadata']}" for chunk in results])
89
+
90
+ # Provide a detailed answer
91
+ detailed_answer = two_stage_rag_search(query)
92
+
93
+ return result_texts, detailed_answer
94
+
95
+ interface = gr.Interface(
96
+ fn=gradio_interface,
97
+ inputs=[
98
+ gr.Textbox(lines=2, placeholder="Enter your query here..."),
99
+ gr.Dropdown(choices=["positive", "negative"], label="Feedback"),
100
+ ],
101
+ outputs=[
102
+ gr.Textbox(lines=20, placeholder="The search results will be displayed here..."),
103
+ gr.Textbox(lines=20, placeholder="The detailed answer will be displayed here...")
104
+ ],
105
+ title="Advanced RAG Search Engine",
106
+ description="Test the advanced RAG search engine with hybrid search."
107
+ )
108
+
109
+ interface.launch()