Spaces:
Runtime error
Runtime error
NCTCMumbai
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"""
|
3 |
Credit to Derek Thomas, [email protected]
|
4 |
"""
|
5 |
-
|
6 |
import subprocess
|
7 |
|
8 |
# subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
|
@@ -59,46 +59,68 @@ def bot(history, cross_encoder):
|
|
59 |
raise ValueError("Empty string was submitted")
|
60 |
|
61 |
logger.warning('Retrieving documents...')
|
62 |
-
# Retrieve documents relevant to query
|
63 |
-
document_start = perf_counter()
|
64 |
-
|
65 |
-
query_vec = retriever.encode(query)
|
66 |
-
logger.warning(f'Finished query vec')
|
67 |
-
doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
|
68 |
-
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
else:
|
80 |
-
|
81 |
-
|
82 |
-
sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
|
83 |
-
logger.warning(f'Finished cross encoder {len(documents)}')
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
|
104 |
with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
|
@@ -128,7 +150,7 @@ with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
|
|
128 |
)
|
129 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
130 |
|
131 |
-
cross_encoder = gr.Radio(choices=['MiniLM-L6v2','BGE reranker'], value='BGE reranker',label="Embeddings", info="Choose MiniLM for Speed, BGE reranker for accuracy")
|
132 |
|
133 |
prompt_html = gr.HTML()
|
134 |
# Turn off interactivity while generating if you click
|
|
|
2 |
"""
|
3 |
Credit to Derek Thomas, [email protected]
|
4 |
"""
|
5 |
+
from ragatouille import RAGPretrainedModel
|
6 |
import subprocess
|
7 |
|
8 |
# subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
|
|
|
59 |
raise ValueError("Empty string was submitted")
|
60 |
|
61 |
logger.warning('Retrieving documents...')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
# if COLBERT RAGATATOUILLE PROCEDURE :
|
64 |
+
if cross_encoder=='ColBERT':
|
65 |
+
gr.Warning('Retrieving using ColBERT')
|
66 |
+
RAG= RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
|
67 |
+
RAG_db=RAG.from_index('.ragatouille/colbert/indexes/mockingbird')
|
68 |
+
documents_full=RAG_db.search(query)
|
69 |
+
|
70 |
+
documents=[item['content'] for item in documents_full]
|
71 |
+
# Create Prompt
|
72 |
+
prompt = template.render(documents=documents, query=query)
|
73 |
+
prompt_html = template_html.render(documents=documents, query=query)
|
74 |
+
|
75 |
+
generate_fn = generate_hf
|
76 |
+
|
77 |
+
history[-1][1] = ""
|
78 |
+
for character in generate_fn(prompt, history[:-1]):
|
79 |
+
history[-1][1] = character
|
80 |
+
print('Final history is ',history)
|
81 |
+
yield history, prompt_html
|
82 |
else:
|
83 |
+
# Retrieve documents relevant to query
|
84 |
+
document_start = perf_counter()
|
|
|
|
|
85 |
|
86 |
+
query_vec = retriever.encode(query)
|
87 |
+
logger.warning(f'Finished query vec')
|
88 |
+
doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
logger.warning(f'Finished search')
|
93 |
+
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
|
94 |
+
documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
|
95 |
+
logger.warning(f'start cross encoder {len(documents)}')
|
96 |
+
# Retrieve documents relevant to query
|
97 |
+
query_doc_pair = [[query, doc] for doc in documents]
|
98 |
+
if cross_encoder=='MiniLM-L6v2' :
|
99 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
100 |
+
elif cross_encoder=='BGE reranker':
|
101 |
+
cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
|
102 |
+
|
103 |
+
cross_scores = cross_encoder.predict(query_doc_pair)
|
104 |
+
sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
|
105 |
+
logger.warning(f'Finished cross encoder {len(documents)}')
|
106 |
+
|
107 |
+
documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
|
108 |
+
logger.warning(f'num documents {len(documents)}')
|
109 |
+
|
110 |
+
document_time = perf_counter() - document_start
|
111 |
+
logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
112 |
+
|
113 |
+
# Create Prompt
|
114 |
+
prompt = template.render(documents=documents, query=query)
|
115 |
+
prompt_html = template_html.render(documents=documents, query=query)
|
116 |
+
|
117 |
+
generate_fn = generate_hf
|
118 |
+
|
119 |
+
history[-1][1] = ""
|
120 |
+
for character in generate_fn(prompt, history[:-1]):
|
121 |
+
history[-1][1] = character
|
122 |
+
print('Final history is ',history)
|
123 |
+
yield history, prompt_html
|
124 |
|
125 |
|
126 |
with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
|
|
|
150 |
)
|
151 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
152 |
|
153 |
+
cross_encoder = gr.Radio(choices=['MiniLM-L6v2','BGE reranker','ColBERT'], value='BGE reranker',label="Embeddings", info="Choose MiniLM for Speed, BGE reranker for accuracy,ColBERT for both")
|
154 |
|
155 |
prompt_html = gr.HTML()
|
156 |
# Turn off interactivity while generating if you click
|