from transformers import pipeline
from sentence_transformers import SentenceTransformer, CrossEncoder
import numpy as np
import faiss
import pickle
from rank_bm25 import BM25Okapi
import gradio as gr

# Generation model for the two-stage RAG pipeline
# (the Hugging Face Hub id is lowercase: "tiiuae/falcon-7b")
generator = pipeline("text-generation", model="tiiuae/falcon-7b")

# Bi-encoder for dense retrieval and cross-encoder for reranking
embedder = SentenceTransformer('all-MiniLM-L6-v2')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Load the precomputed chunk embeddings and the FAISS index
# (the raw embeddings are not used again below; retrieval queries
# the FAISS index directly)
with open("Assemesment5_day4.my_faiss_embeddings.pkl", "rb") as f:
    embeddings = pickle.load(f)

faiss_index = faiss.read_index("my_faiss_index.faiss")

# Load chunks
with open('chunks.pkl', 'rb') as f:
    chunks = pickle.load(f)

# Sparse keyword index over the chunk texts (each chunk is a dict with
# 'text' and 'metadata' keys)
bm25 = BM25Okapi([chunk['text'].split() for chunk in chunks])
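
# For reference, a minimal sketch of how the artifacts loaded above could be
# built. This is an assumption about the preprocessing step (which is not
# part of this file), kept consistent with how the artifacts are consumed
# here; the helper is illustrative and never called.
def build_artifacts(chunks,
                    index_path="my_faiss_index.faiss",
                    embeddings_path="Assemesment5_day4.my_faiss_embeddings.pkl",
                    chunks_path="chunks.pkl"):
    texts = [chunk['text'] for chunk in chunks]
    # all-MiniLM-L6-v2 produces 384-dim vectors; FAISS expects float32
    vectors = np.asarray(embedder.encode(texts), dtype='float32')
    index = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search
    index.add(vectors)
    faiss.write_index(index, index_path)
    with open(embeddings_path, "wb") as f:
        pickle.dump(vectors, f)
    with open(chunks_path, "wb") as f:
        pickle.dump(chunks, f)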


def hybrid_search(query, top_k=5):
    query_tokens = query.split()

    # BM25 (sparse) retrieval
    bm25_scores = bm25.get_scores(query_tokens)
    top_bm25_indices = np.argsort(bm25_scores)[::-1][:top_k]

    # FAISS (dense) retrieval; the index expects float32 vectors
    query_embedding = np.asarray(embedder.encode([query]), dtype='float32')
    distances, faiss_indices = faiss_index.search(query_embedding, top_k)

    # Merge the two candidate lists, de-duplicating while preserving
    # retrieval order (np.unique would sort by chunk index, not relevance,
    # so truncating it kept the lowest-numbered chunks rather than the
    # best-ranked ones). FAISS pads with -1 when it finds fewer than top_k.
    candidate_indices = [i for i in dict.fromkeys(
        list(top_bm25_indices) + list(faiss_indices[0])) if i >= 0]

    combined_chunks = [chunks[i] for i in candidate_indices]
    pairs = [(query, chunk['text']) for chunk in combined_chunks]

    # Cross-encoder reranking; sort on the score alone, since comparing
    # (score, dict) tuples raises TypeError whenever two scores tie
    scores = cross_encoder.predict(pairs)
    reranked_chunks = [chunk for _, chunk in sorted(
        zip(scores, combined_chunks), key=lambda pair: pair[0], reverse=True)]

    return reranked_chunks[:top_k]
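
# Illustrative usage (the query string is a made-up example):
#   top_chunks = hybrid_search("EU AI regulation news", top_k=5)
#   for chunk in top_chunks:
#       print(chunk['metadata'], chunk['text'][:80])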

def two_stage_rag_search(query, top_k=5):
    results = hybrid_search(query, top_k)
    
    context = "\n\n".join([chunk['text'] for chunk in results])
    
    # Stage 1: ask the generator to extract the passage most relevant
    # to the question from the retrieved context
    extraction_prompt = (
        f"Given the following context, extract the most relevant passage that answers the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Relevant Passage:\n"
    )
    
    # return_full_text=False keeps the prompt out of the returned string
    # (by default the pipeline echoes the prompt, so .strip() alone would
    # leave the whole context in relevant_passage); max_new_tokens bounds
    # the completion without counting the prompt, unlike max_length
    extraction_response = generator(extraction_prompt, max_new_tokens=300,
                                    num_return_sequences=1, return_full_text=False)
    relevant_passage = extraction_response[0]['generated_text'].strip()

    # Stage 2: generate the final answer; the "Answer:" cue is placed last
    # so the model's continuation is the answer itself
    answer_prompt = (
        f"Based on the passage below, generate a detailed and thoughtful answer to the question.\n\n"
        f"Relevant Passage: {relevant_passage}\n\n"
        f"Question: {query}\n\n"
        f"Format your response as follows:\n"
        f"Metadata:\n"
        f"Author: 'author'\n"
        f"Title: 'title'\n"
        f"Date: 'date'\n"
        f"Description: 'description'\n\n"
        f"Content or text: 'content'\n\n"
        f"Answer:\n"
    )
    
    answer_response = generator(answer_prompt, max_new_tokens=500,
                                num_return_sequences=1, return_full_text=False)
    final_answer = answer_response[0]['generated_text'].strip()
    
    return final_answer
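
# Illustrative usage (made-up query):
#   print(two_stage_rag_search("Who authored the climate summit report?"))
# The extract-then-answer split keeps the second prompt short, so the
# generator's context window isn't consumed by all top_k chunks at once.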

def gradio_interface(query, feedback=None):
    # feedback defaults to None: the feedback dropdown is currently
    # commented out of the interface inputs below, so Gradio only
    # passes the query through
    results = hybrid_search(query, top_k=5)

    # Format the reranked chunks for display
    result_texts = "\n\n".join(
        [f"Text: {chunk['text']}\nMetadata: {chunk['metadata']}" for chunk in results])
    
    # Provide a detailed answer
    detailed_answer = two_stage_rag_search(query)
    
    return result_texts, detailed_answer

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, label="Query", placeholder="Enter your query here...")
        # gr.Dropdown(choices=["positive", "negative"], label="Feedback"),
    ],
    outputs=[
        gr.Textbox(lines=20, label="Search results", placeholder="The search results will be displayed here..."),
        gr.Textbox(lines=20, label="Detailed answer", placeholder="The detailed answer will be displayed here...")
    ],
    title="News Search Engine",
    description="Hybrid BM25 + FAISS retrieval with cross-encoder reranking and a two-stage RAG answer generator."
)

# share=True serves the app through a temporary public gradio.live link
interface.launch(share=True)