complynx committed on
Commit
7588eb3
·
1 Parent(s): dbf2edc

Add model and some other thing

Files changed (3)
  1. app.py +36 -20
  2. backend/query_llm.py +25 -20
  3. backend/semantic_search.py +25 -4
app.py CHANGED
@@ -9,10 +9,12 @@ from time import perf_counter
 import gradio as gr
 from jinja2 import Environment, FileSystemLoader
 
-from backend.query_llm import generate_hf, generate_openai
+from backend.query_llm import generate_hf, generate_openai, hf_models, openai_models
 from backend.semantic_search import retrieve
 import itertools
 
+inf_models = list(hf_models.keys()) + list(openai_models)
+
 emb_models = ["bge", "minilm"]
 splitters = ['ct', 'rct', 'nltk']
 chunk_sizes = ["500", "2000"]
@@ -56,7 +58,8 @@ def has_balanced_backticks(markdown_str):
     # If in_code_block is False at the end, all backticks are balanced
     return not in_code_block
 
-def bot(history, api_kind,
+def bot(history, model_name, oepnai_api_key,
+        reranker_enabled,reranker_kind,num_prerank_docs,
         num_docs, model_kind, sub_vector_size, chunk_size, splitter_type, all_at_once):
     query = history[-1][0]
 
@@ -67,7 +70,10 @@ def bot(history, api_kind,
     # Retrieve documents relevant to query
     document_start = perf_counter()
 
-    documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type)
+    if reranker_enabled:
+        documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type,reranker_kind,num_prerank_docs)
+    else:
+        documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type)
 
     document_time = perf_counter() - document_start
     logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
@@ -77,18 +83,18 @@ def bot(history, api_kind,
     prompt_html = template_html.render(documents=documents, query=query)
 
 
-    if api_kind == "HuggingFace":
+    if model_name == "HuggingFace":
         generate_fn = generate_hf
-    elif api_kind == "OpenAI":
+    elif model_name == "OpenAI":
         generate_fn = generate_openai
     else:
-        raise gr.Error(f"API {api_kind} is not supported")
+        raise gr.Error(f"API {model_name} is not supported")
 
 
     history[-1][1] = ""
     if all_at_once:
-        for model_name, doc, size, sub_vector in combinations:
-            documents_i = retrieve(query, int(num_docs), model_name, sub_vector, size, doc)
+        for emb_model, doc, size, sub_vector in combinations:
+            documents_i = retrieve(query, int(num_docs), emb_model, sub_vector, size, doc)
             prompt_i = template.render(documents=documents_i, query=query)
             prompt_html = template_html.render(documents=documents, query=query)
 
@@ -96,13 +102,13 @@ def bot(history, api_kind,
             prev_hist = history[-1][1]
             if not has_balanced_backticks(prev_hist):
                 prev_hist += "\n```\n"
-            prev_hist += f"\n\n## model {model_name}, splitter {doc}, size {size}, sub vector {sub_vector}\n\n"
-            for character in generate_fn(prompt_i, history[:-1]):
+            prev_hist += f"\n\n## model {emb_model}, splitter {doc}, size {size}, sub vector {sub_vector}\n\n"
+            for character in generate_fn(model_name, prompt_i, history[:-1], oepnai_api_key):
                 hist_chunk = character
                 history[-1][1] = prev_hist + hist_chunk
                 yield history, prompt_html
     else:
-        for character in generate_fn(prompt, history[:-1]):
+        for character in generate_fn(model_name, prompt, history[:-1], oepnai_api_key):
             history[-1][1] = character
             yield history, prompt_html
 
@@ -129,20 +135,30 @@ with gr.Blocks() as demo:
 
 
     with gr.Row():
-        num_docs = gr.Slider(1, 20, label="number of docs", step=1, value=4)
-        model_kind = gr.Radio(choices=emb_models, value="bge", label="embedding model")
+        emb_model_kind = gr.Radio(choices=emb_models, value="bge", label="embedding model")
         sub_vector_size = gr.Radio(choices=sub_vectors, value="32", label="sub-vector size")
-        all_at_once = gr.Checkbox(value=False, label="Run all at once")
-    with gr.Row():
-        api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace", label="Chat model engine")
         chunk_size = gr.Radio(choices=chunk_sizes, value="2000", label="chunk size")
         splitter_type = gr.Radio(choices=splitters, value="nltk", label="splitter")
+    with gr.Row():
+        reranker_enabled = gr.Checkbox(value=False, label="Reranker enabled")
+        reranker_kind = gr.Radio(choices=emb_models, value="bge", label="Reranker model")
+        num_prerank_docs = gr.Slider(5, 80, label="Number of docs before reranker", step=1, value=20)
+    with gr.Row():
+        num_docs = gr.Slider(1, 20, label="number of docs", step=1, value=4)
+        all_at_once = gr.Checkbox(value=False, label="Run all at once")
+        model_name = gr.Radio(choices=inf_models, value=inf_models[0], label="Chat model")
+        oepnai_api_key = gr.Textbox(
+            show_label=False,
+            placeholder="OpenAI API key",
+            container=False,
+        )
 
     prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind,
-              num_docs, model_kind, sub_vector_size, chunk_size, splitter_type, all_at_once
+        bot, [chatbot, model_name, oepnai_api_key,
+              reranker_enabled,reranker_kind,num_prerank_docs,
+              num_docs, emb_model_kind, sub_vector_size, chunk_size, splitter_type, all_at_once
         ], [chatbot, prompt_html])
 
     # Turn it back on
@@ -150,8 +166,8 @@ with gr.Blocks() as demo:
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind,
-              num_docs, model_kind, sub_vector_size, chunk_size, splitter_type
+        bot, [chatbot, model_name,
+              num_docs, emb_model_kind, sub_vector_size, chunk_size, splitter_type
        ], [chatbot, prompt_html])
 
     # Turn it back on
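Editor's note: the all_at_once branch of bot() unpacks a combinations iterable that is defined outside the hunks shown in this diff. As a rough sketch only (the actual definition in app.py may differ), such an iterable could be built from the option lists declared at the top of the file with itertools.product; the sub_vectors values below are illustrative, since that list is referenced but not shown here:

```python
import itertools

# Option lists as declared at the top of app.py
emb_models = ["bge", "minilm"]
splitters = ['ct', 'rct', 'nltk']
chunk_sizes = ["500", "2000"]
sub_vectors = ["32"]  # illustrative; the real sub_vectors list is not part of this diff

# Hypothetical construction of the (emb_model, splitter, chunk_size, sub_vector)
# tuples unpacked by the all_at_once loop in bot()
combinations = list(itertools.product(emb_models, splitters, chunk_sizes, sub_vectors))

for emb_model, doc, size, sub_vector in combinations:
    print(f"model {emb_model}, splitter {doc}, size {size}, sub vector {sub_vector}")
```

The ordering matches how the loop unpacks each tuple and how it forwards the values to retrieve(query, int(num_docs), emb_model, sub_vector, size, doc).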
backend/query_llm.py CHANGED
@@ -8,15 +8,19 @@ from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
 
 
-OPENAI_KEY = os.getenv("OPENAI_API_KEY")
 HF_TOKEN = os.getenv("HF_TOKEN")
-TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))
 
-HF_CLIENT = InferenceClient(
-    os.getenv("HF_MODEL"),
-    token=HF_TOKEN
-)
-OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
+hf_models = {
+    "mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
+    "mistral-7B 0.1": "mistralai/Mistral-7B-v0.1",
+    "llama 3": "meta-llama/Meta-Llama-3-70B-Instruct",
+}
+openai_models = {"gpt-4o","gpt-3.5-turbo-0125"}
+
+tokenizers = {k: AutoTokenizer.from_pretrained(m) for k,m in hf_models.items()}
+clients = {k: InferenceClient(
+    m, token=HF_TOKEN
+) for k,m in hf_models.items()}
 
 HF_GENERATE_KWARGS = {
     'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
@@ -34,7 +38,7 @@ OAI_GENERATE_KWARGS = {
 }
 
 
-def format_prompt(message: str, api_kind: str):
+def format_prompt(message: str, model: str):
     """
     Formats the given message using a chat template.
 
@@ -48,15 +52,15 @@ def format_prompt(message: str, api_kind: str):
     # Create a list of message dictionaries with role and content
     messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
 
-    if api_kind == "openai":
+    if model in openai_models:
         return messages
-    elif api_kind == "hf":
-        return TOKENIZER.apply_chat_template(messages, tokenize=False)
-    elif api_kind:
-        raise ValueError("API is not supported")
+    elif model in hf_models:
+        return tokenizers[model].apply_chat_template(messages, tokenize=False)
+    else:
+        raise ValueError(f"Model {model} is not supported")
 
 
-def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_hf(model: str, prompt: str, history: str, _: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -68,11 +72,11 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         Returns a final string if an error occurs.
     """
 
-    formatted_prompt = format_prompt(prompt, "hf")
+    formatted_prompt = format_prompt(prompt, model)
     formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")
 
     try:
-        stream = HF_CLIENT.text_generation(
+        stream = clients[model].text_generation(
            formatted_prompt,
            **HF_GENERATE_KWARGS,
            stream=True,
@@ -93,7 +97,7 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
         raise gr.Error(f"Unhandled Exception: {str(e)}")
 
 
-def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
+def generate_openai(model: str, prompt: str, history: str, api_key: str) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -104,11 +108,12 @@ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
         Generator[str, None, str]: A generator yielding chunks of generated text.
         Returns a final string if an error occurs.
     """
-    formatted_prompt = format_prompt(prompt, "openai")
+    formatted_prompt = format_prompt(prompt, model)
+    client = openai.Client(api_key=api_key)
 
     try:
-        stream = OAI_CLIENT.chat.completions.create(
-            model=os.getenv("OPENAI_MODEL"),
+        stream = client.chat.completions.create(
+            model=model,
             messages=formatted_prompt,
             **OAI_GENERATE_KWARGS,
             stream=True
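Editor's note: with these changes the generators take the model key, and for OpenAI the API key, per call instead of reading a single HF_MODEL/OPENAI_MODEL and key from the environment at import time. A minimal usage sketch, assuming backend.query_llm is importable and HF_TOKEN is set; the prompt and key below are placeholders:

```python
from backend.query_llm import generate_hf, generate_openai, hf_models

model = "mistral-7B"          # any key of hf_models
prompt = "What is LanceDB?"   # placeholder prompt
history = []                  # prior chat turns (not used by format_prompt)

# generate_hf ignores its last argument (the slot reserved for the OpenAI key),
# so an empty string can be passed. bot() assigns each yielded value straight to
# history[-1][1], which suggests the generator yields the accumulated text so far.
text = ""
for text in generate_hf(model, prompt, history, ""):
    pass
print(text)

# For an OpenAI model the key now travels with the request:
# for text in generate_openai("gpt-4o", prompt, history, "<your OpenAI key>"):
#     ...
```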
backend/semantic_search.py CHANGED
@@ -1,7 +1,7 @@
 import lancedb
 import os
 import gradio as gr
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer, CrossEncoder
 
 
 db = lancedb.connect(".lancedb")
@@ -22,20 +22,41 @@ BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
 retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
 retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
+reranker_bge = CrossEncoder("BAAI/bge-reranker-large")
+reranker_minilm = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
 
-def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type):
+
+def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type, reranker_kind=None, pre_ranker_size=None):
     if model_kind == "bge":
         query_vec = retriever_bge.encode(query)
     else:
         query_vec = retriever_minilm.encode(query)
 
+    if pre_ranker_size is None:
+        pre_ranker_size = k
+
     try:
         documents = table(
             f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}",
-        ).search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
+        ).search(query_vec, vector_column_name=VECTOR_COLUMN).limit(pre_ranker_size).to_list()
         documents = [doc[TEXT_COLUMN] for doc in documents]
 
-        return documents
+        if reranker_kind is None:
+            return documents
+        # Pair the query with each document for re-ranking
+        query_document_pairs = [(query, text) for text in documents]
+
+        # Score documents using the reranker
+        if reranker_kind == "bge":
+            scores = reranker_bge.predict(query_document_pairs)
+        else:
+            scores = reranker_minilm.predict(query_document_pairs)
+
+        # Aggregate and sort the documents based on the scores
+        scored_documents = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
+
+        # Return the top K documents based on re-ranking
+        return [doc for doc, _ in scored_documents[:k]]
 
     except Exception as e:
         raise gr.Error(str(e))
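Editor's note: the new reranker path follows the usual over-retrieve-then-rerank pattern: fetch pre_ranker_size candidates by vector similarity, score each (query, document) pair with a cross-encoder, and keep the top k. A self-contained sketch of just that scoring step, using the same MiniLM cross-encoder checkpoint as the diff; the query and candidate texts are made up for illustration:

```python
from sentence_transformers import CrossEncoder

# Same MiniLM cross-encoder used as the non-bge reranker in the diff
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "How do I connect to a LanceDB database?"
candidates = [  # stand-ins for the documents returned by the vector search
    "lancedb.connect('.lancedb') opens or creates a local database directory.",
    "Gradio sliders expose numeric parameters in the UI.",
    "The BGE embedding model maps sentences to dense vectors.",
]

# Score each (query, document) pair; higher score means more relevant
scores = reranker.predict([(query, doc) for doc in candidates])

# Keep the top-k documents by score, as retrieve() does after the LanceDB search
k = 2
top_k = [doc for doc, _ in sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)[:k]]
print(top_k)
```

retrieve() applies exactly this sort-and-slice after the LanceDB search when reranker_kind is passed; with reranker_kind=None it returns the vector-search hits unchanged.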