Spaces:
Runtime error
Runtime error
Proper Chatbot Interface
Browse files
app.py
CHANGED
@@ -21,11 +21,15 @@ MAX_INPUT_TOKEN_LENGTH = 4000
|
|
21 |
EMBED_DIM = 1024
|
22 |
K = 10
|
23 |
EF = 100
|
|
|
24 |
SEARCH_INDEX = "search_index.bin"
|
25 |
EMBEDDINGS_FILE = "embeddings.npy"
|
26 |
DOCUMENT_DATASET = "chunked_data.parquet"
|
27 |
COSINE_THRESHOLD = 0.7
|
28 |
|
|
|
|
|
|
|
29 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
30 |
print("Running on device:", torch_device)
|
31 |
print("CPU threads:", torch.get_num_threads())
|
@@ -36,6 +40,11 @@ cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length
|
|
36 |
|
37 |
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
|
38 |
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
def create_qa_prompt(query, relevant_chunks):
|
41 |
stuffed_context = " ".join(relevant_chunks)
|
@@ -112,7 +121,7 @@ def get_completion(
|
|
112 |
return response["choices"][0]["message"]["content"] if not stream else response
|
113 |
|
114 |
|
115 |
-
# load the index for the
|
116 |
def load_hnsw_index(index_file):
|
117 |
# Load the HNSW index from the specified file
|
118 |
index = hnswlib.Index(space="ip", dim=EMBED_DIM)
|
@@ -120,7 +129,7 @@ def load_hnsw_index(index_file):
|
|
120 |
return index
|
121 |
|
122 |
|
123 |
-
# create the index for the
|
124 |
# avoid the arch mismatches when creating search index
|
125 |
def create_hnsw_index(embeddings_file, M=16, efC=100):
|
126 |
embeddings = np.load(embeddings_file)
|
@@ -181,7 +190,7 @@ DEFAULT_MAX_NEW_TOKENS = 1024
|
|
181 |
MAX_INPUT_TOKEN_LENGTH = 4000
|
182 |
|
183 |
DESCRIPTION = """
|
184 |
-
#
|
185 |
"""
|
186 |
|
187 |
LICENSE = """
|
@@ -285,6 +294,18 @@ def check_input_token_length(message: str, chat_history: list[tuple[str, str]],
|
|
285 |
f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
|
286 |
)
|
287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
search_index = create_hnsw_index(EMBEDDINGS_FILE) # load_hnsw_index(SEARCH_INDEX)
|
290 |
data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
|
@@ -341,11 +362,12 @@ with gr.Blocks(css="style.css") as demo:
|
|
341 |
|
342 |
gr.Examples(
|
343 |
examples=[
|
344 |
-
"What is
|
345 |
-
"
|
346 |
-
"What
|
347 |
-
"How
|
348 |
-
"What are the
|
|
|
349 |
],
|
350 |
inputs=textbox,
|
351 |
outputs=[textbox, chatbot],
|
|
|
21 |
EMBED_DIM = 1024
|
22 |
K = 10
|
23 |
EF = 100
|
24 |
+
TEXT_FILE = 'data.txt'
|
25 |
SEARCH_INDEX = "search_index.bin"
|
26 |
EMBEDDINGS_FILE = "embeddings.npy"
|
27 |
DOCUMENT_DATASET = "chunked_data.parquet"
|
28 |
COSINE_THRESHOLD = 0.7
|
29 |
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
34 |
print("Running on device:", torch_device)
|
35 |
print("CPU threads:", torch.get_num_threads())
|
|
|
40 |
|
41 |
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
|
42 |
|
43 |
+
def read_text_from_file(file_path):
    """Read a text file and split it into chunks on the ``&&`` delimiter.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of strings, one per ``&&``-separated segment, each with
        leading and trailing whitespace removed. Empty segments (for example
        the one after a trailing ``&&``) are kept, so a chunk's position in
        the list always matches its position in the file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # encoding, which varies between hosts and locales.
    with open(file_path, "r", encoding="utf-8") as text_file:
        text = text_file.read()
    return [chunk.strip() for chunk in text.split("&&")]
|
48 |
|
49 |
def create_qa_prompt(query, relevant_chunks):
|
50 |
stuffed_context = " ".join(relevant_chunks)
|
|
|
121 |
return response["choices"][0]["message"]["content"] if not stream else response
|
122 |
|
123 |
|
124 |
+
# load the index for the data
|
125 |
def load_hnsw_index(index_file):
|
126 |
# Load the HNSW index from the specified file
|
127 |
index = hnswlib.Index(space="ip", dim=EMBED_DIM)
|
|
|
129 |
return index
|
130 |
|
131 |
|
132 |
+
# create the index for the data from numpy embeddings
|
133 |
# avoid the arch mismatches when creating search index
|
134 |
def create_hnsw_index(embeddings_file, M=16, efC=100):
|
135 |
embeddings = np.load(embeddings_file)
|
|
|
190 |
MAX_INPUT_TOKEN_LENGTH = 4000
|
191 |
|
192 |
DESCRIPTION = """
|
193 |
+
# AVA Southampton Chatbot 🤗
|
194 |
"""
|
195 |
|
196 |
LICENSE = """
|
|
|
294 |
f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
|
295 |
)
|
296 |
|
297 |
+
# One-time data bootstrap, executed at import time. Each artifact is built
# only when its file is missing, so later starts reuse what is already on disk.
if not os.path.exists(TEXT_FILE):
    import shutil
    import urllib.request

    # Download to a temporary name and rename only after the transfer has
    # finished. A failed or interrupted download raises here instead of
    # leaving a truncated TEXT_FILE that the existence check above would
    # accept as a complete corpus on the next start.
    _partial_path = TEXT_FILE + ".part"
    with urllib.request.urlopen(
        "https://huggingface.co/spaces/Slycat/Southampton-Similarity/resolve/main/Southampton.txt"
    ) as _response, open(_partial_path, "wb") as _out_file:
        shutil.copyfileobj(_response, _out_file)
    os.replace(_partial_path, TEXT_FILE)

if not os.path.exists(EMBEDDINGS_FILE):
    # Encode every chunk of the corpus and cache the resulting array.
    texts = read_text_from_file(TEXT_FILE)
    embeddings = biencoder.encode(texts, normalize_embeddings=True)
    np.save(EMBEDDINGS_FILE, embeddings)

if not os.path.exists(DOCUMENT_DATASET):
    # Store the same chunks, in the same order, as a one-column table.
    texts = read_text_from_file(TEXT_FILE)
    df = pd.DataFrame(texts, columns=["chunk_content"])
    df.to_parquet(DOCUMENT_DATASET, index=False)
|
309 |
|
310 |
search_index = create_hnsw_index(EMBEDDINGS_FILE) # load_hnsw_index(SEARCH_INDEX)
|
311 |
data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
|
|
|
362 |
|
363 |
gr.Examples(
|
364 |
examples=[
|
365 |
+
"What is University of Southampton?",
|
366 |
+
"Is University of Southampton Good?",
|
367 |
+
"What is sports facility at southampton university?",
|
368 |
+
"How big is the Southampton campus?",
|
369 |
+
"What are the rankings of southampton university?",
|
370 |
+
"What research facilities does the Southampton university offer?"
|
371 |
],
|
372 |
inputs=textbox,
|
373 |
outputs=[textbox, chatbot],
|