Spaces:
Sleeping
Sleeping
fix
Browse files
app.py
CHANGED
@@ -43,13 +43,13 @@ def load_title_embeddings():
|
|
43 |
|
44 |
|
45 |
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
|
46 |
-
batch_dict = tokenizer(f"query: {input_text}", max_length=512, padding=True, truncation=True, return_tensors='pt')
|
47 |
with torch.no_grad():
|
48 |
outputs = model(**batch_dict)
|
49 |
query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
50 |
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
51 |
|
52 |
-
_, ids = index.search(x=
|
53 |
retrieved_titles = []
|
54 |
retrieved_pids = []
|
55 |
|
|
|
43 |
|
44 |
|
45 |
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
|
46 |
+
batch_dict = tokenizer([f"query: {input_text}"], max_length=512, padding=True, truncation=True, return_tensors='pt')
|
47 |
with torch.no_grad():
|
48 |
outputs = model(**batch_dict)
|
49 |
query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
50 |
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
51 |
|
52 |
+
_, ids = index.search(x=query_embeddings.detach().numpy().copy(), k=top_k)
|
53 |
retrieved_titles = []
|
54 |
retrieved_pids = []
|
55 |
|