NCTCMumbai committed
Commit 4a192fa · verified · Parent: d63f3fb

Update app.py

Files changed (1): app.py (+61 -39)

app.py CHANGED
@@ -2,7 +2,7 @@
 """
 Credit to Derek Thomas, [email protected]
 """
-
+from ragatouille import RAGPretrainedModel
 import subprocess

 # subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
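The only change in this hunk is the new RAGatouille import (PyPI package ragatouille), which the ColBERT branch added further down depends on. That branch also expects a ColBERT index named 'mockingbird' to already exist on disk; below is a minimal sketch of how such an index is typically built with RAGatouille. The document list is illustrative and not from this repository; only the index name and the default output location match what the committed code loads.

# One-time index build (a sketch, not part of app.py). By default RAGatouille writes the
# index under .ragatouille/colbert/indexes/<index_name>, which is the path bot() loads.
from ragatouille import RAGPretrainedModel

RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
RAG.index(
    collection=["Chapter 1 text ...", "Chapter 2 text ..."],  # illustrative documents
    index_name="mockingbird",
)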
@@ -59,46 +59,68 @@ def bot(history, cross_encoder):
         raise ValueError("Empty string was submitted")

     logger.warning('Retrieving documents...')
-    # Retrieve documents relevant to query
-    document_start = perf_counter()
-
-    query_vec = retriever.encode(query)
-    logger.warning(f'Finished query vec')
-    doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
-

-
-    logger.warning(f'Finished search')
-    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
-    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
-    logger.warning(f'start cross encoder {len(documents)}')
-    # Retrieve documents relevant to query
-    query_doc_pair = [[query, doc] for doc in documents]
-    if cross_encoder == 'MiniLM-L6v2':
-        cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+    # If the ColBERT / RAGatouille retrieval option was selected:
+    if cross_encoder == 'ColBERT':
+        gr.Warning('Retrieving using ColBERT')
+        RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
+        RAG_db = RAG.from_index('.ragatouille/colbert/indexes/mockingbird')
+        documents_full = RAG_db.search(query)
+
+        documents = [item['content'] for item in documents_full]
+        # Create Prompt
+        prompt = template.render(documents=documents, query=query)
+        prompt_html = template_html.render(documents=documents, query=query)
+
+        generate_fn = generate_hf
+
+        history[-1][1] = ""
+        for character in generate_fn(prompt, history[:-1]):
+            history[-1][1] = character
+            print('Final history is ', history)
+            yield history, prompt_html
     else:
-        cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
-    cross_scores = cross_encoder.predict(query_doc_pair)
-    sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
-    logger.warning(f'Finished cross encoder {len(documents)}')
+        # Retrieve documents relevant to query
+        document_start = perf_counter()

-    documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
-    logger.warning(f'num documents {len(documents)}')
-
-    document_time = perf_counter() - document_start
-    logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
-
-    # Create Prompt
-    prompt = template.render(documents=documents, query=query)
-    prompt_html = template_html.render(documents=documents, query=query)
-
-    generate_fn = generate_hf
-
-    history[-1][1] = ""
-    for character in generate_fn(prompt, history[:-1]):
-        history[-1][1] = character
-        print('Final history is ', history)
-        yield history, prompt_html
+        query_vec = retriever.encode(query)
+        logger.warning(f'Finished query vec')
+        doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
+
+
+
+        logger.warning(f'Finished search')
+        documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
+        documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
+        logger.warning(f'start cross encoder {len(documents)}')
+        # Rerank the retrieved documents against the query
+        query_doc_pair = [[query, doc] for doc in documents]
+        if cross_encoder == 'MiniLM-L6v2':
+            cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+        elif cross_encoder == 'BGE reranker':
+            cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
+
+        cross_scores = cross_encoder.predict(query_doc_pair)
+        sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
+        logger.warning(f'Finished cross encoder {len(documents)}')
+
+        documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
+        logger.warning(f'num documents {len(documents)}')
+
+        document_time = perf_counter() - document_start
+        logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
+
+        # Create Prompt
+        prompt = template.render(documents=documents, query=query)
+        prompt_html = template_html.render(documents=documents, query=query)
+
+        generate_fn = generate_hf
+
+        history[-1][1] = ""
+        for character in generate_fn(prompt, history[:-1]):
+            history[-1][1] = character
+            print('Final history is ', history)
+            yield history, prompt_html


 with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
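Note on the new ColBERT branch: RAGPretrainedModel.from_index() loads the model and index that were saved at build time, so the preceding from_pretrained() call is most likely redundant, though harmless. A standalone sketch of this retrieval path, assuming the 'mockingbird' index exists; the query text and k value are illustrative, not from the commit:

# Sketch (not from app.py): query a prebuilt RAGatouille/ColBERT index.
from ragatouille import RAGPretrainedModel

RAG = RAGPretrainedModel.from_index('.ragatouille/colbert/indexes/mockingbird')
results = RAG.search('What happens in the trial scene?', k=5)  # illustrative query and k
for hit in results:
    # Each hit is a dict; app.py reads hit['content'], and a relevance score is reported as well.
    print(hit.get('score'), hit['content'][:80])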
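The else branch keeps the existing flow: encode the query, pull candidates from LanceDB, then rerank them with a cross-encoder and keep the top_k_rank best. Here is a minimal sketch of that reranking step on an in-memory candidate list; the query, candidates, and top_k below are made up for illustration:

# Sketch (not from app.py): cross-encoder reranking of retrieved candidates.
import numpy as np
from sentence_transformers import CrossEncoder

query = 'Who wrote To Kill a Mockingbird?'         # illustrative
candidates = [                                     # stand-in for the LanceDB search results
    'To Kill a Mockingbird was written by Harper Lee.',
    'The mockingbird is the state bird of Texas.',
    'Harper Lee published the novel in 1960.',
]

reranker = CrossEncoder('BAAI/bge-reranker-base')  # or 'cross-encoder/ms-marco-MiniLM-L-6-v2'
scores = reranker.predict([[query, doc] for doc in candidates])

top_k = 2                                          # plays the role of top_k_rank in app.py
order = np.argsort(scores)[::-1][:top_k]           # indices by descending score
print([candidates[i] for i in order])

One consequence of switching the bare else to an elif: inside bot() the cross_encoder argument starts out as the radio's string value, and any value other than 'MiniLM-L6v2' or 'BGE reranker' that reached this branch would no longer be replaced by a CrossEncoder instance, so the predict() call would fail.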
@@ -128,7 +150,7 @@ with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
     )
     txt_btn = gr.Button(value="Submit text", scale=1)

-    cross_encoder = gr.Radio(choices=['MiniLM-L6v2', 'BGE reranker'], value='BGE reranker', label="Embeddings", info="Choose MiniLM for Speed, BGE reranker for accuracy")
+    cross_encoder = gr.Radio(choices=['MiniLM-L6v2', 'BGE reranker', 'ColBERT'], value='BGE reranker', label="Embeddings", info="Choose MiniLM for Speed, BGE reranker for accuracy, ColBERT for both")

     prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
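The radio gains a 'ColBERT' choice so the selection can drive the new branch in bot(). The event wiring that actually passes this component's value into bot() sits outside this diff; the sketch below shows how such wiring commonly looks in a Gradio Blocks app. All component and helper names here are illustrative, not taken from app.py.

# Sketch (not from app.py): passing a gr.Radio selection into a streaming chat callback.
import gradio as gr

def bot(history, cross_encoder):
    # app.py branches here on 'MiniLM-L6v2' / 'BGE reranker' / 'ColBERT'.
    history[-1][1] = f"(answered using {cross_encoder})"
    yield history

def add_text(history, text):
    return history + [[text, None]], ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    txt = gr.Textbox()
    cross_encoder = gr.Radio(choices=['MiniLM-L6v2', 'BGE reranker', 'ColBERT'],
                             value='BGE reranker', label="Embeddings")

    # Passing the Radio component as an extra input delivers its current value
    # to bot() as the cross_encoder argument on every submit.
    txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
        bot, [chatbot, cross_encoder], chatbot
    )

demo.launch()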
 