File size: 6,670 Bytes
e0169c8
 
 
 
dafbeb8
 
e0169c8
 
 
 
 
eeafaaa
10ddae5
e0169c8
 
8b1c859
34b78ab
 
10ddae5
e0169c8
360f505
e0169c8
 
 
 
 
 
360f505
e0169c8
 
8b1c859
 
e0169c8
10ddae5
 
e0169c8
8b1c859
 
eeafaaa
 
8b1c859
 
 
e0169c8
 
 
 
8b1c859
e0169c8
 
 
10ddae5
8b1c859
e0169c8
 
 
34b78ab
e0169c8
360f505
34b78ab
10ddae5
 
 
 
 
 
e0169c8
10ddae5
34b78ab
10ddae5
 
 
8b1c859
 
e0169c8
 
10ddae5
 
e0169c8
34b78ab
 
10ddae5
34b78ab
 
 
10ddae5
 
34b78ab
 
8b1c859
 
eeafaaa
 
34b78ab
 
 
8b1c859
 
e0169c8
8b1c859
e0169c8
34b78ab
10ddae5
 
34b78ab
8b1c859
 
 
10ddae5
 
e0169c8
 
 
 
8b1c859
 
 
 
 
 
 
 
 
10ddae5
8b1c859
 
 
 
 
 
 
 
 
 
 
10ddae5
34b78ab
10ddae5
 
34b78ab
10ddae5
 
 
 
 
 
 
 
 
 
 
34b78ab
 
 
 
 
 
 
 
 
 
 
8b1c859
10ddae5
 
 
 
cfc7185
 
10ddae5
 
 
 
 
8b1c859
 
 
 
 
e0169c8
 
8b1c859
 
 
10ddae5
8b1c859
e0169c8
 
8b1c859
e0169c8
 
8b1c859
10ddae5
e0169c8
 
8b1c859
e0169c8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
Credit to Derek Thomas, [email protected]
"""

# import subprocess
# subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])

import logging
from time import perf_counter

import gradio as gr
import markdown
import lancedb
from jinja2 import Environment, FileSystemLoader

from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
from gradio_app.backend.query_llm import *
from gradio_app.backend.embedders import EmbedderFactory

from settings import *

# Setting up the logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up the template environment with the templates directory
env = Environment(loader=FileSystemLoader('gradio_app/templates'))

# Load the templates directly from the environment
context_template = env.get_template('context_template.j2')  # plain-text context injected into the LLM prompt
context_html_template = env.get_template('context_html_template.j2')  # HTML view of the same context for the UI panel

# Vector store connection; individual tables are opened per request in bot()
db = lancedb.connect(LANCEDB_DIRECTORY)

# Example queries offered in the UI (wired to the input textbox via gr.Examples below)
examples = [
    'What is BERT?',
    'Tell me about GPT',
    'How to use accelerate in google colab?',
    'What is the capital of China?',
    'Why is the sky blue?',
]


def add_text(history, text):
    """Append the user's message (with an empty bot reply) to the chat history.

    Returns the updated history plus a cleared, non-interactive textbox so the
    input is locked while the bot generates its answer.
    """
    updated = list(history or [])
    updated.append((text, ""))
    return updated, gr.Textbox(value="", interactive=False)


def bot(history, llm, cross_enc, chunk, embed):
    """
    Answer the latest user message in *history* with retrieval-augmented generation.

    Pipeline: embed the query, retrieve candidate chunks from LanceDB, optionally
    rerank them with a cross-encoder, drop documents until the prompt fits the
    model's context window, then stream the LLM answer back part by part.

    :param history: chat history; the last entry holds [query, answer-so-far]
    :param llm: LLM name; selects message constructor, generator and context length
    :param cross_enc: cross-encoder model name, or None to skip reranking
    :param chunk: chunking-policy name; part of the LanceDB table name
    :param embed: embedder name; selects both the table and the query embedder
    :yields: (updated history, rendered HTML of the retrieved context)
    :raises gr.Error: on empty query, empty retrieval, or exceeded context length
    """
    history[-1][1] = ""
    query = history[-1][0]

    if not query:
        raise gr.Error("Empty string was submitted")

    logger.info('Retrieving documents...')
    gr.Info('Start documents retrieval ...')
    t = perf_counter()

    # One table exists per (chunking policy, embedder) combination.
    table_name = f'{LANCEDB_TABLE_NAME}_{chunk}_{embed}'
    table = db.open_table(table_name)

    embedder = EmbedderFactory.get_embedder(embed)

    query_vec = embedder.embed([query])[0]
    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
    # Retrieve a larger candidate pool when a cross-encoder will rerank it down.
    top_k_rank = TOP_K_RANK if cross_enc is not None else TOP_K_RERANK
    documents = documents.limit(top_k_rank).to_list()
    # Guard: min() below raises a bare ValueError on an empty hit list;
    # surface the problem to the user instead.
    if not documents:
        raise gr.Error('No documents were retrieved for this query')
    # Filter by distance threshold, but keep at least the single closest hit
    # even when it lies beyond the configured threshold.
    thresh_dist = thresh_distances[embed]
    thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
    documents = [d for d in documents if d['_distance'] <= thresh_dist]
    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]

    t = perf_counter() - t
    logger.info(f'Finished Retrieving documents in {round(t, 2)} seconds...')

    logger.info('Reranking documents...')
    gr.Info('Start documents reranking ...')
    t = perf_counter()

    documents = rerank_with_cross_encoder(cross_enc, documents, query)

    t = perf_counter() - t
    logger.info(f'Finished Reranking documents in {round(t, 2)} seconds...')

    # Drop the lowest-ranked documents one by one until the prompt leaves
    # 512 tokens of headroom for the completion; if none fit, give up.
    msg_constructor = get_message_constructor(llm)
    while documents:
        context = context_template.render(documents=documents)
        documents_html = [markdown.markdown(d) for d in documents]
        context_html = context_html_template.render(documents=documents_html)
        messages = msg_constructor(context, history)
        num_tokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo')  # todo for HF, it is approximation
        if num_tokens + 512 < context_lengths[llm]:
            break
        documents.pop()
    else:
        raise gr.Error('Model context length exceeded, reload the page')

    llm_gen = get_llm_generator(llm)
    logger.info('Generating answer...')
    t = perf_counter()
    # Stream the answer: each yielded part extends the last chat message.
    for part in llm_gen(messages):
        history[-1][1] += part
        yield history, context_html
    t = perf_counter() - t
    logger.info(f'Finished Generating answer in {round(t, 2)} seconds...')


# ---- Gradio UI layout and event wiring ----
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Chat panel: left avatar is the user, right avatar the HF logo.
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                               'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
                bubble_full_width=False,
                show_copy_button=True,
                show_share_button=True,
                height=500,
            )

            with gr.Row():
                input_textbox = gr.Textbox(
                    scale=3,
                    show_label=False,
                    placeholder="Enter text and press enter",
                    container=False,
                )
                txt_btn = gr.Button(value="Submit text", scale=1)

            # Radio values below are passed straight into bot() and select the
            # LanceDB table, embedder, reranker and LLM per request.
            chunk_name = gr.Radio(
                choices=[
                    "md",
                    "txt",
                ],
                value="md",
                label='Chunking policy'
            )

            embed_name = gr.Radio(
                choices=[
                    "text-embedding-ada-002",
                    "sentence-transformers/all-MiniLM-L6-v2",
                ],
                value="text-embedding-ada-002",
                label='Embedder'
            )

            cross_enc_name = gr.Radio(
                choices=[
                    None,
                    "cross-encoder/ms-marco-TinyBERT-L-2-v2",
                    "cross-encoder/ms-marco-MiniLM-L-12-v2",
                ],
                value=None,  # None disables reranking in bot()
                label='Cross-Encoder'
            )

            llm_name = gr.Radio(
                choices=[
                    "gpt-3.5-turbo",
                    "mistralai/Mistral-7B-Instruct-v0.1",
                    "tiiuae/falcon-180B-chat",
                    # "GeneZC/MiniChat-3B",
                ],
                value="gpt-3.5-turbo",
                label='LLM'
            )

            # Examples
            gr.Examples(examples, input_textbox)

        with gr.Column():
            # Right-hand panel showing the retrieved context rendered as HTML.
            context_html = gr.HTML()

    # Turn off interactivity while generating if you click
    txt_msg = txt_btn.click(
        add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
    ).then(
        bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html]
    )

    # Turn it back on
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)

    # Turn off interactivity while generating if you hit enter
    txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
        bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html])

    # Turn it back on
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)

# Queueing is required for generator callbacks (bot streams its answer).
demo.queue()
demo.launch(debug=True)