File size: 6,085 Bytes
d52cca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efb7284
 
d52cca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2db09e8
efb7284
2db09e8
7a16e1c
 
2db09e8
 
d52cca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efb7284
d52cca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

"""
Credit to Derek Thomas, [email protected]
"""

import subprocess

# subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])

import logging
from pathlib import Path
from time import perf_counter

import gradio as gr
from jinja2 import Environment, FileSystemLoader
import numpy as np
from sentence_transformers import CrossEncoder

from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import table, retriever

# Column names read from the vector-search `table` results in bot():
# the embedding column searched against, and the raw chunk text column.
VECTOR_COLUMN_NAME = "vector"
TEXT_COLUMN_NAME = "text"

# Directory holding this file; the templates folder is resolved relative to it.
proj_dir = Path(__file__).parent

# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Jinja2 environment that loads templates from <proj_dir>/templates.
env = Environment(loader=FileSystemLoader(proj_dir.joinpath('templates')))

# `template` renders the prompt sent to the LLM; `template_html` renders the
# HTML view of that prompt shown in the UI. Loaded in this order, eagerly.
template, template_html = (env.get_template(name) for name in ('template.j2', 'template_html.j2'))

# The cross-encoder reranker is chosen per request from the UI radio button
# (see bot()), so no reranker model is created at import time.

# Example questions offered under the input box via gr.Examples.
examples = [
    'what is social media and what are rules related to it for expenditure monitoring',
    'how many reports to be submitted by Expenditure observer with annexure names ?',
]


def add_text(history, text):
    """Record a newly submitted user message and lock the input box.

    Args:
        history: The chatbot's current list of (user, bot) turns, or None when
            the conversation has not started yet.
        text: The message the user just submitted.

    Returns:
        A tuple of (new history list with ``(text, None)`` appended — the bot
        reply is filled in later — and a cleared, non-interactive Textbox).
        The incoming ``history`` list is not mutated.
    """
    if history is None:
        history = []
    updated_history = history + [(text, None)]
    locked_box = gr.Textbox(value="", interactive=False)
    return updated_history, locked_box


# Reranker choice shown in the UI radio button -> Hugging Face model id.
_CROSS_ENCODER_MODELS = {
    'MiniLM-L6v2': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'BGE reranker': 'BAAI/bge-reranker-base',
}
# Any unrecognised choice falls back to the BGE reranker (same as the old if/else).
_DEFAULT_CROSS_ENCODER_MODEL = 'BAAI/bge-reranker-base'
# Loaded CrossEncoder instances keyed by model id. This avoids reloading model
# weights on every chat message.
_cross_encoder_cache = {}


def _get_cross_encoder(choice):
    """Return a cached CrossEncoder for the UI reranker ``choice``, loading it on first use."""
    model_name = _CROSS_ENCODER_MODELS.get(choice, _DEFAULT_CROSS_ENCODER_MODEL)
    model = _cross_encoder_cache.get(model_name)
    if model is None:
        logger.warning('Loading cross encoder %s', model_name)
        model = CrossEncoder(model_name)
        _cross_encoder_cache[model_name] = model
    return model


def _rerank(query, documents, choice, top_k):
    """Return at most ``top_k`` of ``documents``, ordered by descending cross-encoder score."""
    if not documents:
        # Nothing retrieved: skip model loading and scoring entirely.
        return []
    model = _get_cross_encoder(choice)
    scores = model.predict([[query, doc] for doc in documents])
    # argsort is ascending; reverse it so the highest-scoring documents come first.
    best_first = np.argsort(scores)[::-1][:top_k]
    return [documents[idx] for idx in best_first]


def bot(history, cross_encoder):
    """Answer the latest chat message with retrieval-augmented generation.

    Steps:
        1. Fetch candidate chunks by vector search.
        2. Rerank them with the selected cross-encoder.
        3. Render the prompt templates.
        4. Stream the LLM answer into the last chat turn.

    Args:
        history: gr.Chatbot history. The first element of the last entry is the
            user query; its second element is overwritten with the answer.
        cross_encoder: Reranker choice from the UI radio. 'MiniLM-L6v2' selects
            the MiniLM cross-encoder; any other value selects the BGE reranker.

    Yields:
        ``(history, prompt_html)`` for each value produced by the generator.
        Each value replaces the reply text, so the generator presumably yields
        the accumulated answer so far. ``prompt_html`` is the rendered HTML
        view of the prompt.

    Raises:
        ValueError: If the submitted query is empty (after showing a UI warning).
    """
    top_rerank = 15   # candidates fetched from the vector search
    top_k_rank = 10   # candidates kept after reranking
    query = history[-1][0]

    if not query:
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty string was submitted")

    logger.warning('Retrieving documents...')
    document_start = perf_counter()

    query_vec = retriever.encode(query)
    logger.warning('Finished query vec')
    results = (
        table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
        .limit(top_rerank)
        .to_list()
    )
    documents = [doc[TEXT_COLUMN_NAME] for doc in results]
    logger.warning('Finished search; start cross encoder on %d documents', len(documents))

    documents = _rerank(query, documents, cross_encoder, top_k_rank)
    logger.warning('Finished cross encoder; num documents %d', len(documents))

    document_time = perf_counter() - document_start
    logger.warning('Finished Retrieving documents in %s seconds...', round(document_time, 2))

    # Render the LLM prompt and its HTML preview from the same inputs.
    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    generate_fn = generate_hf

    # Rebuild the last turn as a mutable list. add_text appends a tuple, and
    # item assignment would fail if that tuple reaches us unconverted.
    history[-1] = [query, ""]
    for partial_response in generate_fn(prompt, history[:-1]):
        history[-1][1] = partial_response
        yield history, prompt_html
    logger.info('Final history is %s', history)


with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
    # NOTE(review): a bare relative src='logo.png' may not be served by Gradio;
    # it may need a '/file=logo.png' URL plus allowed_paths. Confirm in the deployed Space.
    gr.HTML(value="""<div style="display: flex; align-items: center; justify-content: space-between;">
    <h1 style="color: #008000">NIRVACHANA - <span style="color: #008000">Expenditure Observer AI Assistant</span></h1>
    <img src='logo.png' alt="Chatbot" width="50" height="50" />
    </div>""",elem_id='heading')
    gr.HTML(value="""<p style="font-family: sans-serif; font-size: 16px;">A free chat bot assistant for Expenditure Observers on Compendium on Election Expenditure Monitoring using Open source LLMs. <br> The bot can answer questions in natural language, taking relevant extracts from the ECI document which can be accessed <a href="https://www.eci.gov.in/eci-backend/public/api/download?url=LMAhAK6sOPBp%2FNFF0iRfXbEB1EVSLT41NNLRjYNJJP1KivrUxbfqkDatmHy12e%2Fzk1vx4ptJpQsKYHA87guoLjnPUWtHeZgKtEqs%2FyzfTTYIC0newOHHOjl1rl0u3mJBSIq%2Fi7zDsrcP74v%2FKr8UNw%3D%3D" style="color: #008000; text-decoration: none;">here</a>.</p>""",elem_id='Sub-heading')

    chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                           'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
            bubble_full_width=False,
            show_copy_button=True,
            show_share_button=True,
            )

    with gr.Row():
        txt = gr.Textbox(
                scale=3,
                show_label=False,
                placeholder="Enter text and press enter",
                container=False,
                )
        txt_btn = gr.Button(value="Submit text", scale=1)

    # Selects the cross-encoder bot() uses to rerank retrieved chunks
    # (it is a reranker, not an embedding model).
    cross_encoder = gr.Radio(choices=['MiniLM-L6v2','BGE reranker'], value='BGE reranker',label="Reranker", info="Choose MiniLM for Speed, BGE reranker for accuracy")

    prompt_html = gr.HTML()

    # Clicking the button and pressing Enter run the same chain, registered in
    # that order:
    #   add_text (locks the textbox) -> bot (streams the answer) -> re-enable the textbox.
    # The final .then() runs after bot finishes, whether or not it raised.
    for trigger in (txt_btn.click, txt.submit):
        txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
                bot, [chatbot, cross_encoder], [chatbot, prompt_html])
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # Clickable example questions that fill the textbox.
    gr.Examples(examples, txt)

# Enable the request queue (needed for streaming generator handlers), then start the server.
demo.queue().launch(debug=True)