orionweller committed on
Commit 53b3bb9 · 1 Parent(s): d89580e
app.py CHANGED
@@ -34,7 +34,7 @@ corpus_lookups = {}
 queries = {}
 q_lookups = {}
 qrels = {}
-datasets = ["scifact"] # others are too large for the Space unfortunately :(
+datasets = ["scifact"]
 current_dataset = "scifact"
 
 def pool(last_hidden_states, attention_mask):
@@ -68,61 +68,45 @@ def load_model():
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
 
-    base_model_instance = AutoModel.from_pretrained(BASE_MODEL)
+    base_model_instance = AutoModel.from_pretrained(BASE_MODEL, device_map="auto", torch_dtype=torch.float16)
     model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
-    model = model.merge_and_unload()
     model.eval()
 
-def save_faiss_index(index, dataset_name):
-    index_path = f"{dataset_name}/faiss_index.bin"
-    faiss.write_index(index, index_path)
-    logger.info(f"Saved FAISS index for {dataset_name} to {index_path}")
-
 def load_faiss_index(dataset_name):
     index_path = f"{dataset_name}/faiss_index.bin"
     if os.path.exists(index_path):
         logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
-        return faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
+        return faiss.read_index(index_path, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
     return None
 
-def load_corpus_embeddings(dataset_name):
-    global retrievers, corpus_lookups
-    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
-    index_files = glob.glob(corpus_path)
-    logger.info(f'Loading {len(index_files)} files into index for {dataset_name}.')
-
-    # Try to load existing FAISS index
+def search_queries(dataset_name, q_reps, depth=1000):
     faiss_index = load_faiss_index(dataset_name)
-
     if faiss_index is None:
-        # If no existing index, create a new one
-        p_reps_0, p_lookup_0 = pickle_load(index_files[0])
-        retrievers[dataset_name] = FaissFlatSearcher(p_reps_0)
-
-        shards = [(p_reps_0, p_lookup_0)] + [pickle_load(f) for f in index_files[1:]]
-        corpus_lookups[dataset_name] = []
-
-        for p_reps, p_lookup in tqdm.tqdm(shards, desc=f'Loading shards into index for {dataset_name}', total=len(index_files)):
-            retrievers[dataset_name].add(p_reps)
-            corpus_lookups[dataset_name] += p_lookup
-
-        # Save the newly created index
-        save_faiss_index(retrievers[dataset_name].index, dataset_name)
-    else:
-        # Use the loaded index
-        retrievers[dataset_name] = FaissFlatSearcher(faiss_index)
-
-    # Load corpus lookups
-    corpus_lookups[dataset_name] = []
-    for file in index_files:
-        _, p_lookup = pickle_load(file)
-        corpus_lookups[dataset_name] += p_lookup
-
+        raise ValueError(f"No FAISS index found for dataset {dataset_name}")
+
+    # Ensure q_reps is a 2D numpy array of the correct type
+    q_reps = np.ascontiguousarray(q_reps.astype('float32'))
+
+    # Perform the search
+    all_scores, all_indices = faiss_index.search(q_reps, depth)
+
+    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
+
+    # Clean up
+    del faiss_index
+
+    return all_scores, np.array(psg_indices)
 
-def pickle_load(path):
-    with open(path, 'rb') as f:
-        reps, lookup = pickle.load(f)
-    return np.array(reps), lookup
+def load_corpus_lookups(dataset_name):
+    global corpus_lookups
+    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
+    index_files = glob.glob(corpus_path)
+
+    corpus_lookups[dataset_name] = []
+    for file in index_files:
+        with open(file, 'rb') as f:
+            _, p_lookup = pickle.load(f)
+        corpus_lookups[dataset_name] += p_lookup
 
 def load_queries(dataset_name):
     global queries, q_lookups, qrels
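The hunk above swaps the in-memory FaissFlatSearcher build for a prebuilt index opened memory-mapped and read-only, mirroring the flags in the new `load_faiss_index`. A toy sketch of that flow, with made-up dimensions and a placeholder file name:

```python
# Build a small flat index once, then reopen it memory-mapped so search
# does not pull every vector into RAM. Dimensions and path are illustrative.
import faiss
import numpy as np

dim = 8
vecs = np.random.rand(100, dim).astype("float32")
index = faiss.IndexFlatIP(dim)          # inner-product flat index
index.add(vecs)
faiss.write_index(index, "faiss_index.bin")

mmap_index = faiss.read_index("faiss_index.bin",
                              faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
q = np.ascontiguousarray(vecs[:2])       # queries: contiguous float32, shape (n, dim)
scores, ids = mmap_index.search(q, 5)    # top-5 integer ids per query
```

The integer ids are then mapped back to document ids through the `corpus_lookups` lists, exactly as `search_queries` does with `psg_indices`.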
@@ -143,7 +127,6 @@ def load_queries(dataset_name):
 @spaces.GPU
 def encode_queries(dataset_name, postfix):
     global queries, tokenizer, model
-    model = model.cuda()
    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[dataset_name]]
 
     encoded_embeds = []
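Dropping `model = model.cuda()` follows from the earlier hunk: with `device_map="auto"`, accelerate places the weights at load time, so no manual device shuffling is needed. A minimal sketch of that loading pattern, assuming hypothetical checkpoint names rather than the Space's actual config:

```python
# fp16 base model + unmerged LoRA adapter, placed by accelerate.
# Both repo ids below are placeholders, not the Space's real BASE_MODEL/CUR_MODEL.
import torch
from transformers import AutoModel
from peft import PeftModel

BASE_MODEL = "org/some-base-encoder"   # hypothetical
CUR_MODEL = "org/some-lora-adapter"    # hypothetical

base = AutoModel.from_pretrained(BASE_MODEL, device_map="auto",
                                 torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, CUR_MODEL)  # adapter kept unmerged
model.eval()
```

Note the diff also removes `merge_and_unload()`; keeping the adapter unmerged avoids the extra memory spike of materializing merged weights, which matters on a constrained Space.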
@@ -161,15 +144,8 @@ def encode_queries(dataset_name, postfix):
         embeds = F.normalize(embeds, p=2, dim=-1)
         encoded_embeds.append(embeds.cpu().numpy())
 
-
-    # remove model from GPU
-    model = model.cpu()
     return np.concatenate(encoded_embeds, axis=0)
 
-def search_queries(dataset_name, q_reps, depth=1000):
-    all_scores, all_indices = retrievers[dataset_name].search(q_reps, depth)
-    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
-    return all_scores, np.array(psg_indices)
 
 def evaluate(qrels, results, k_values):
     evaluator = pytrec_eval.RelevanceEvaluator(
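The `F.normalize(..., p=2, dim=-1)` step above is what makes inner-product search over the flat index behave like cosine similarity. A self-contained sketch of the pooling-plus-normalization pattern, assuming mask-aware mean pooling (the app's own `pool()` may use a different strategy, e.g. last-token pooling):

```python
# Mask-aware mean pooling over transformer outputs, then L2 normalization
# so that dot products between embeddings equal cosine similarities.
import torch
import torch.nn.functional as F

def mean_pool(last_hidden_states, attention_mask):
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq, 1)
    summed = (last_hidden_states * mask).sum(dim=1)  # zero out padding positions
    counts = mask.sum(dim=1).clamp(min=1e-9)         # tokens per example
    return summed / counts

hidden = torch.randn(2, 5, 16)   # fake (batch, seq, hidden) activations
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
embeds = F.normalize(mean_pool(hidden, mask), p=2, dim=-1)  # unit-length rows
```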
@@ -187,8 +163,8 @@ def evaluate(qrels, results, k_values):
 def run_evaluation(dataset, postfix):
     global current_dataset
 
-    if dataset not in retrievers or dataset not in queries:
-        load_corpus_embeddings(dataset)
+    if dataset not in corpus_lookups or dataset not in queries:
+        load_corpus_lookups(dataset)
         load_queries(dataset)
 
     current_dataset = dataset
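`evaluate()` is a thin wrapper around pytrec_eval. A tiny standalone example of that API with toy qrels and a toy run; the measure names here are chosen for illustration, while the app's `k_values` drive the real measure set:

```python
# qrels: query id -> {doc id: relevance}; results: query id -> {doc id: score}.
import pytrec_eval

qrels = {"q1": {"d1": 1, "d2": 0}}
results = {"q1": {"d1": 12.3, "d2": 7.1, "d3": 3.2}}

evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg_cut.10", "recall.100"})
metrics = evaluator.evaluate(results)              # per-query measure dict
print(metrics["q1"]["ndcg_cut_10"], metrics["q1"]["recall_100"])
```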
@@ -208,16 +184,14 @@ def run_evaluation(dataset, postfix):
 
 def gradio_interface(dataset, postfix):
     if 'model' not in globals() or model is None:
-        # Load model and initial datasets
         load_model()
         for dataset in datasets:
             print(f"Loading dataset: {dataset}")
-            load_corpus_embeddings(dataset)
+            load_corpus_lookups(dataset)
             load_queries(dataset)
 
     return run_evaluation(dataset, postfix)
 
-
 # Create Gradio interface
 iface = gr.Interface(
     fn=gradio_interface,
@@ -230,7 +204,7 @@ iface = gr.Interface(
     description="Select a dataset and enter a prompt to evaluate the model's performance. Note: it takes about **ten seconds** to evaluate.",
     examples=[
         ["scifact", ""],
-        ["scifact", "When judging the relevance of a document, focus on the pragmatics of the query and consider irrelevant any documents for which the user would have used a different query."]
+        ["scifact", "Think carefully about these conditions when determining relevance."]
     ],
     cache_examples=True,
 )
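For reference, the overall UI shape: a bare-bones `gr.Interface` sketch with stand-in components and a placeholder function, while the real app wires `gradio_interface`/`run_evaluation` as shown in the diff:

```python
# Minimal Gradio sketch; component choices here are assumptions.
import gradio as gr

def fake_eval(dataset, postfix):
    return {"dataset": dataset, "prompt": postfix, "ndcg@10": 0.0}  # placeholder output

iface = gr.Interface(
    fn=fake_eval,
    inputs=[gr.Dropdown(choices=["scifact"], value="scifact"),
            gr.Textbox(label="Prompt postfix")],
    outputs=gr.JSON(),
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()
```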
scifact/corpus_emb.0.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0bb98e68350983519732b0b39e8f98ec0225abd2c68775e7317da9b17f0db1dd
-size 21247618

scifact/corpus_emb.1.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3dd3501342754aeb2ffb895480868e0976895bded3e5accbd8e5b6fa404e5484
-size 21247619

scifact/corpus_emb.2.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0e1a98c698cbe367bc1abc789da76794a8e79e92743059b26faafbd34808aa15
-size 21247619

scifact/corpus_emb.3.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:911c8d6654bfb14a3d68363c96a70462348cfbbf35a591e020877ed28591339c
-size 21231225