Warlord-K committed on
Commit
01df155
·
1 Parent(s): f1ca1a4

Proper Chatbot Interface

Browse files
Files changed (1) hide show
  1. app.py +30 -8
app.py CHANGED
@@ -21,11 +21,15 @@ MAX_INPUT_TOKEN_LENGTH = 4000
21
  EMBED_DIM = 1024
22
  K = 10
23
  EF = 100
 
24
  SEARCH_INDEX = "search_index.bin"
25
  EMBEDDINGS_FILE = "embeddings.npy"
26
  DOCUMENT_DATASET = "chunked_data.parquet"
27
  COSINE_THRESHOLD = 0.7
28
 
 
 
 
29
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
30
  print("Running on device:", torch_device)
31
  print("CPU threads:", torch.get_num_threads())
@@ -36,6 +40,11 @@ cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length
36
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
38
 
 
 
 
 
 
39
 
40
  def create_qa_prompt(query, relevant_chunks):
41
  stuffed_context = " ".join(relevant_chunks)
@@ -112,7 +121,7 @@ def get_completion(
112
  return response["choices"][0]["message"]["content"] if not stream else response
113
 
114
 
115
- # load the index for the PEFT docs
116
  def load_hnsw_index(index_file):
117
  # Load the HNSW index from the specified file
118
  index = hnswlib.Index(space="ip", dim=EMBED_DIM)
@@ -120,7 +129,7 @@ def load_hnsw_index(index_file):
120
  return index
121
 
122
 
123
- # create the index for the PEFT docs from numpy embeddings
124
  # avoid the arch mismatches when creating search index
125
  def create_hnsw_index(embeddings_file, M=16, efC=100):
126
  embeddings = np.load(embeddings_file)
@@ -181,7 +190,7 @@ DEFAULT_MAX_NEW_TOKENS = 1024
181
  MAX_INPUT_TOKEN_LENGTH = 4000
182
 
183
  DESCRIPTION = """
184
- # PEFT Docs QA Chatbot 🤗
185
  """
186
 
187
  LICENSE = """
@@ -285,6 +294,18 @@ def check_input_token_length(message: str, chat_history: list[tuple[str, str]],
285
  f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
286
  )
287
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  search_index = create_hnsw_index(EMBEDDINGS_FILE) # load_hnsw_index(SEARCH_INDEX)
290
  data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
@@ -341,11 +362,12 @@ with gr.Blocks(css="style.css") as demo:
341
 
342
  gr.Examples(
343
  examples=[
344
- "What is 🤗 PEFT?",
345
- "How do I create a LoraConfig?",
346
- "What are the different tuners supported?",
347
- "How do I use LoRA with custom models?",
348
- "What are the different real-world applications that I can use PEFT for?",
 
349
  ],
350
  inputs=textbox,
351
  outputs=[textbox, chatbot],
 
21
  EMBED_DIM = 1024
22
  K = 10
23
  EF = 100
24
+ TEXT_FILE = 'data.txt'
25
  SEARCH_INDEX = "search_index.bin"
26
  EMBEDDINGS_FILE = "embeddings.npy"
27
  DOCUMENT_DATASET = "chunked_data.parquet"
28
  COSINE_THRESHOLD = 0.7
29
 
30
+
31
+
32
+
33
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
34
  print("Running on device:", torch_device)
35
  print("CPU threads:", torch.get_num_threads())
 
40
 
41
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
42
 
43
def read_text_from_file(file_path: str) -> list[str]:
    """Read a text file and split it into chunks on the "&&" delimiter.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The chunks in file order, each with surrounding whitespace removed.
        Empty chunks (for example the one produced by a trailing "&&") are
        kept, so the position of every chunk in the returned list is stable.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which differs between hosts and can mangle non-ASCII text.
    with open(file_path, "r", encoding="utf-8") as text_file:
        text = text_file.read()
    return [chunk.strip() for chunk in text.split("&&")]
48
 
49
  def create_qa_prompt(query, relevant_chunks):
50
  stuffed_context = " ".join(relevant_chunks)
 
121
  return response["choices"][0]["message"]["content"] if not stream else response
122
 
123
 
124
+ # load the index for the data
125
  def load_hnsw_index(index_file):
126
  # Load the HNSW index from the specified file
127
  index = hnswlib.Index(space="ip", dim=EMBED_DIM)
 
129
  return index
130
 
131
 
132
+ # create the index for the data from numpy embeddings
133
  # avoid the arch mismatches when creating search index
134
  def create_hnsw_index(embeddings_file, M=16, efC=100):
135
  embeddings = np.load(embeddings_file)
 
190
  MAX_INPUT_TOKEN_LENGTH = 4000
191
 
192
  DESCRIPTION = """
193
+ # AVA Southampton Chatbot 🤗
194
  """
195
 
196
  LICENSE = """
 
294
  f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
295
  )
296
 
297
# Source of the raw corpus; chunks inside it are separated by "&&".
_TEXT_SOURCE_URL = "https://huggingface.co/spaces/Slycat/Southampton-Similarity/resolve/main/Southampton.txt"


def _download_text_file(url, destination):
    """Download `url` to `destination`, leaving no partial file on failure.

    The payload is written to a temporary sibling path and only renamed onto
    `destination` once the transfer has completed. The startup checks below
    only test whether the file exists, so a truncated or empty download must
    never be left under the final name.

    Raises:
        urllib.error.URLError: If the download fails (includes HTTP errors).
        OSError: If the file cannot be written or renamed.
    """
    # Local imports keep this block self-contained without touching the
    # file-level import section.
    import shutil
    import urllib.request

    partial_path = f"{destination}.part"
    try:
        with urllib.request.urlopen(url) as response, open(partial_path, "wb") as out_file:
            shutil.copyfileobj(response, out_file)
        os.replace(partial_path, destination)
    finally:
        # After a successful rename the partial file no longer exists; after
        # a failure this removes whatever was written so far.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# Fetch the raw corpus once; later runs reuse the local copy.
if not os.path.exists(TEXT_FILE):
    _download_text_file(_TEXT_SOURCE_URL, TEXT_FILE)

# Build the embedding matrix from the text chunks if it is not cached yet.
if not os.path.exists(EMBEDDINGS_FILE):
    texts = read_text_from_file(TEXT_FILE)
    embeddings = biencoder.encode(texts, normalize_embeddings=True)
    np.save(EMBEDDINGS_FILE, embeddings)

# Build the chunk table from the same chunk list, in the same order as the
# embeddings above, with one row per chunk in the "chunk_content" column.
if not os.path.exists(DOCUMENT_DATASET):
    texts = read_text_from_file(TEXT_FILE)
    df = pd.DataFrame(texts, columns=["chunk_content"])
    df.to_parquet(DOCUMENT_DATASET, index=False)
309
 
310
  search_index = create_hnsw_index(EMBEDDINGS_FILE) # load_hnsw_index(SEARCH_INDEX)
311
  data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
 
362
 
363
  gr.Examples(
364
  examples=[
365
+ "What is University of Southampton?",
366
+ "Is University of Southampton Good?",
367
+ "What is sports facility at southampton university?",
368
+ "How big is the Southampton campus?",
369
+ "What are the rankings of southampton university?",
370
+ "What research facilities does the Southampton university offer?"
371
  ],
372
  inputs=textbox,
373
  outputs=[textbox, chatbot],