Add model and some other thing

- app.py +36 -20
- backend/query_llm.py +25 -20
- backend/semantic_search.py +25 -4
app.py
CHANGED
@@ -9,10 +9,12 @@ from time import perf_counter
 import gradio as gr
 from jinja2 import Environment, FileSystemLoader
 
-from backend.query_llm import generate_hf, generate_openai
+from backend.query_llm import generate_hf, generate_openai, hf_models, openai_models
 from backend.semantic_search import retrieve
 import itertools
 
+inf_models = list(hf_models.keys()) + list(openai_models)
+
 emb_models = ["bge", "minilm"]
 splitters = ['ct', 'rct', 'nltk']
 chunk_sizes = ["500", "2000"]
@@ -56,7 +58,8 @@ def has_balanced_backticks(markdown_str):
     # If in_code_block is False at the end, all backticks are balanced
     return not in_code_block
 
-def bot(history, api_kind,
+def bot(history, model_name, oepnai_api_key,
+        reranker_enabled,reranker_kind,num_prerank_docs,
         num_docs, model_kind, sub_vector_size, chunk_size, splitter_type, all_at_once):
     query = history[-1][0]
 
@@ -67,7 +70,10 @@ def bot(history, api_kind,
     # Retrieve documents relevant to query
     document_start = perf_counter()
 
-    documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type)
+    if reranker_enabled:
+        documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type,reranker_kind,num_prerank_docs)
+    else:
+        documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type)
 
     document_time = perf_counter() - document_start
     logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
@@ -77,18 +83,18 @@ def bot(history, api_kind,
     prompt_html = template_html.render(documents=documents, query=query)
 
 
-    if api_kind == "HuggingFace":
+    if model_name == "HuggingFace":
         generate_fn = generate_hf
-    elif api_kind == "OpenAI":
+    elif model_name == "OpenAI":
         generate_fn = generate_openai
     else:
-        raise gr.Error(f"API {api_kind} is not supported")
+        raise gr.Error(f"API {model_name} is not supported")
 
 
     history[-1][1] = ""
     if all_at_once:
-        for
-            documents_i = retrieve(query, int(num_docs),
+        for emb_model, doc, size, sub_vector in combinations:
+            documents_i = retrieve(query, int(num_docs), emb_model, sub_vector, size, doc)
             prompt_i = template.render(documents=documents_i, query=query)
             prompt_html = template_html.render(documents=documents, query=query)
 
@@ -96,13 +102,13 @@ def bot(history, api_kind,
             prev_hist = history[-1][1]
             if not has_balanced_backticks(prev_hist):
                 prev_hist += "\n```\n"
-            prev_hist += f"\n\n## model {
-            for character in generate_fn(prompt_i, history[:-1]):
+            prev_hist += f"\n\n## model {emb_model}, splitter {doc}, size {size}, sub vector {sub_vector}\n\n"
+            for character in generate_fn(model_name, prompt_i, history[:-1], oepnai_api_key):
                 hist_chunk = character
                 history[-1][1] = prev_hist + hist_chunk
                 yield history, prompt_html
     else:
-        for character in generate_fn(prompt, history[:-1]):
+        for character in generate_fn(model_name, prompt, history[:-1], oepnai_api_key):
             history[-1][1] = character
             yield history, prompt_html
 
@@ -129,20 +135,30 @@ with gr.Blocks() as demo:
 
 
     with gr.Row():
-
-        model_kind = gr.Radio(choices=emb_models, value="bge", label="embedding model")
+        emb_model_kind = gr.Radio(choices=emb_models, value="bge", label="embedding model")
         sub_vector_size = gr.Radio(choices=sub_vectors, value="32", label="sub-vector size")
-        all_at_once = gr.Checkbox(value=False, label="Run all at once")
-    with gr.Row():
-        api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace", label="Chat model engine")
         chunk_size = gr.Radio(choices=chunk_sizes, value="2000", label="chunk size")
         splitter_type = gr.Radio(choices=splitters, value="nltk", label="splitter")
+    with gr.Row():
+        reranker_enabled = gr.Checkbox(value=False, label="Reranker enabled")
+        reranker_kind = gr.Radio(choices=emb_models, value="bge", label="Reranker model")
+        num_prerank_docs = gr.Slider(5, 80, label="Number of docs before reranker", step=1, value=20)
+    with gr.Row():
+        num_docs = gr.Slider(1, 20, label="number of docs", step=1, value=4)
+        all_at_once = gr.Checkbox(value=False, label="Run all at once")
+        model_name = gr.Radio(choices=inf_models, value=inf_models[0], label="Chat model")
+        oepnai_api_key = gr.Textbox(
+            show_label=False,
+            placeholder="OpenAI API key",
+            container=False,
+        )
 
     prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot,
-
+        bot, [chatbot, model_name, oepnai_api_key,
+        reranker_enabled,reranker_kind,num_prerank_docs,
+        num_docs, emb_model_kind, sub_vector_size, chunk_size, splitter_type, all_at_once
         ], [chatbot, prompt_html])
 
     # Turn it back on
@@ -150,8 +166,8 @@ with gr.Blocks() as demo:
 
     # Turn off interactivity while generating if you hit enter
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot,
-        num_docs,
+        bot, [chatbot, model_name,
+        num_docs, emb_model_kind, sub_vector_size, chunk_size, splitter_type
         ], [chatbot, prompt_html])
 
     # Turn it back on
backend/query_llm.py
CHANGED
@@ -8,15 +8,19 @@ from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
 
 
-OPENAI_KEY = os.getenv("OPENAI_API_KEY")
 HF_TOKEN = os.getenv("HF_TOKEN")
-TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))
 
-
-
-
-
-
+hf_models = {
+    "mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
+    "mistral-7B 0.1": "mistralai/Mistral-7B-v0.1",
+    "llama 3": "meta-llama/Meta-Llama-3-70B-Instruct",
+}
+openai_models = {"gpt-4o","gpt-3.5-turbo-0125"}
+
+tokenizers = {k: AutoTokenizer.from_pretrained(m) for k,m in hf_models.items()}
+clients = {k: InferenceClient(
+    m, token=HF_TOKEN
+) for k,m in hf_models.items()}
 
 HF_GENERATE_KWARGS = {
     'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
@@ -34,7 +38,7 @@ OAI_GENERATE_KWARGS = {
 }
 
 
-def format_prompt(message: str, api_kind: str):
+def format_prompt(message: str, model: str):
     """
     Formats the given message using a chat template.
 
@@ -48,15 +52,15 @@ def format_prompt(message: str, api_kind: str):
     # Create a list of message dictionaries with role and content
     messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
 
-    if
+    if model in openai_models:
         return messages
-    elif
-        return
-
-        raise ValueError("
+    elif model in hf_models:
+        return tokenizers[model].apply_chat_template(messages, tokenize=False)
+    else:
+        raise ValueError(f"Model {model} is not supported")
 
 
-def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_hf(model: str, prompt: str, history: str, _: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -68,11 +72,11 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         Returns a final string if an error occurs.
     """
 
-    formatted_prompt = format_prompt(prompt,
+    formatted_prompt = format_prompt(prompt, model)
     formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")
 
     try:
-        stream =
+        stream = clients[model].text_generation(
             formatted_prompt,
             **HF_GENERATE_KWARGS,
             stream=True,
@@ -93,7 +97,7 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         raise gr.Error(f"Unhandled Exception: {str(e)}")
 
 
-def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_openai(model: str, prompt: str, history: str, api_key: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -104,11 +108,12 @@ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
         Generator[str, None, str]: A generator yielding chunks of generated text.
         Returns a final string if an error occurs.
     """
-    formatted_prompt = format_prompt(prompt,
+    formatted_prompt = format_prompt(prompt, model)
+    client = openai.Client(api_key=api_key)
 
     try:
-        stream =
-            model=
+        stream = client.chat.completions.create(
+            model=model,
             messages=formatted_prompt,
             **OAI_GENERATE_KWARGS,
             stream=True
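
With the hard-coded OPENAI_KEY and single TOKENIZER replaced by the hf_models / openai_models registries, a caller picks a generate function by model name and streams chunks from it. A minimal usage sketch, assuming it runs from the Space root so backend.query_llm is importable; the membership-based dispatch here is an illustration only (the committed app.py still branches on the literal strings "HuggingFace"/"OpenAI"):

from backend.query_llm import generate_hf, generate_openai, hf_models, openai_models

def pick_generate_fn(model_name: str):
    # Dispatch on the registries exported by backend/query_llm.py.
    if model_name in hf_models:
        return generate_hf
    if model_name in openai_models:
        return generate_openai
    raise ValueError(f"Model {model_name} is not supported")

# Both generate_* functions now share the signature (model, prompt, history, api_key);
# generate_hf ignores the key, generate_openai builds an openai.Client from it.
model = "mistral-7B"
fn = pick_generate_fn(model)
for chunk in fn(model, "What does the reranker checkbox do?", [], ""):
    print(chunk, end="")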
backend/semantic_search.py
CHANGED
@@ -1,7 +1,7 @@
 import lancedb
 import os
 import gradio as gr
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer, CrossEncoder
 
 
 db = lancedb.connect(".lancedb")
@@ -22,20 +22,41 @@ BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
 retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
 retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
+reranker_bge = CrossEncoder("BAAI/bge-reranker-large")
+reranker_minilm = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
 
-def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type):
+
+def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type, reranker_kind=None, pre_ranker_size=None):
     if model_kind == "bge":
         query_vec = retriever_bge.encode(query)
     else:
         query_vec = retriever_minilm.encode(query)
 
+    if pre_ranker_size is None:
+        pre_ranker_size = k
+
     try:
         documents = table(
             f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}",
-        ).search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
+        ).search(query_vec, vector_column_name=VECTOR_COLUMN).limit(pre_ranker_size).to_list()
         documents = [doc[TEXT_COLUMN] for doc in documents]
 
-        return documents
+        if reranker_kind is None:
+            return documents
+        # Pair the query with each document for re-ranking
+        query_document_pairs = [(query, text) for text in documents]
+
+        # Score documents using the reranker
+        if reranker_kind == "bge":
+            scores = reranker_bge.predict(query_document_pairs)
+        else:
+            scores = reranker_minilm.predict(query_document_pairs)
+
+        # Aggregate and sort the documents based on the scores
+        scored_documents = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
+
+        # Return the top K documents based on re-ranking
+        return [doc for doc, _ in scored_documents[:k]]
 
     except Exception as e:
         raise gr.Error(str(e))
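
The new optional arguments turn retrieve() into a two-stage retriever: fetch pre_ranker_size candidates from LanceDB, then let a cross-encoder re-score them and keep the top k. A standalone sketch of that second stage, using the same checkpoint as reranker_minilm above; the query and candidate texts are invented placeholders:

from sentence_transformers import CrossEncoder

# Same cross-encoder checkpoint as reranker_minilm in backend/semantic_search.py.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "How do I open a LanceDB table?"
candidates = [  # placeholder documents standing in for the LanceDB search results
    "db = lancedb.connect(path) opens a database and db.open_table(name) returns a table.",
    "Gradio sliders pass their numeric value to the bound callback.",
    "Cross-encoders re-score (query, document) pairs for higher-precision ranking.",
]

# Score each (query, document) pair, then keep the top-k by score, as retrieve() now does.
pairs = [(query, text) for text in candidates]
scores = reranker.predict(pairs)
k = 2
top_k = [doc for doc, _ in sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)[:k]]
print(top_k)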