spacemanidol
commited on
Commit
•
ac9d0cb
1
Parent(s):
08e7a44
Update README.md
Browse files
README.md
CHANGED
@@ -2997,14 +2997,14 @@ document_tokens = tokenizer(documents, padding=True, truncation=True, return_te
|
|
2997 |
# Compute token embeddings
|
2998 |
with torch.no_grad():
|
2999 |
query_embeddings = model(**query_tokens)[0][:, 0]
|
3000 |
-
|
3001 |
|
3002 |
|
3003 |
# normalize embeddings
|
3004 |
query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
|
3005 |
-
|
3006 |
|
3007 |
-
scores = torch.mm(query_embeddings,
|
3008 |
for query, query_scores in zip(queries, scores):
|
3009 |
doc_score_pairs = list(zip(documents, query_scores))
|
3010 |
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
|
|
|
2997 |
# Compute token embeddings
|
2998 |
with torch.no_grad():
|
2999 |
query_embeddings = model(**query_tokens)[0][:, 0]
|
3000 |
+
document_embeddings = model(**document_tokens)[0][:, 0]
|
3001 |
|
3002 |
|
3003 |
# normalize embeddings
|
3004 |
query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
|
3005 |
+
document_embeddings = torch.nn.functional.normalize(document_embeddings, p=2, dim=1)
|
3006 |
|
3007 |
+
scores = torch.mm(query_embeddings, document_embeddings.transpose(0, 1))
|
3008 |
for query, query_scores in zip(queries, scores):
|
3009 |
doc_score_pairs = list(zip(documents, query_scores))
|
3010 |
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
|