Ramon Meffert committed

Commit 83870cc
Parent(s): 8bbe3aa

Add base model retriever

Files changed:
- README.md +48 -0
- main.py → base_model/main.py +7 -6
- base_model/retriever.py +53 -24
- poetry.lock +29 -1
- pyproject.toml +1 -0
README.md
CHANGED

@@ -25,3 +25,51 @@ Most QA systems consist of two parts:
 
 - Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
 - Overview of open-domain question answering techniques: <https://lilianweng.github.io/posts/2020-10-29-odqa/>
+
+## Base model
+
+So far this is only a retriever that fetches the top-k relevant documents for a
+given question. It does reach high similarity scores for many questions, but
+the documents it retrieves are usually not very relevant.
+
+```bash
+poetry shell
+cd base_model
+poetry run python main.py
+```
+
+### Example
+
+"What is the perplexity of a language model?"
+
+> Result 1 (score: 74.10):
+> Figure 10.17 A sample alignment between sentences in English and French, with
+> sentences extracted from Antoine de Saint-Exupery's Le Petit Prince and a
+> hypothetical translation. Sentence alignment takes sentences e_1, ..., e_n,
+> and f_1, ..., f_n and finds minimal sets of sentences that are translations
+> of each other, including single sentence mappings like (e_1, f_1), (e_4-f_3),
+> (e_5-f_4), (e_6-f_6) as well as 2-1 alignments (e_2/e_3, f_2), (e_7/e_8-f_7),
+> and null alignments (f_5).
+>
+> Result 2 (score: 74.23):
+> Character or word overlap-based metrics like chrF (or BLEU, or etc.) are
+> mainly used to compare two systems, with the goal of answering questions like:
+> did the new algorithm we just invented improve our MT system? To know if the
+> difference between the chrF scores of two MT systems is a significant
+> difference, we use the paired bootstrap test, or the similar randomization
+> test.
+>
+> Result 3 (score: 74.43):
+> The model thus predicts the class negative for the test sentence.
+>
+> Result 4 (score: 74.95):
+> Translating from languages with extensive pro-drop, like Chinese or Japanese,
+> to non-pro-drop languages like English can be difficult since the model must
+> somehow identify each zero and recover who or what is being talked about in
+> order to insert the proper pronoun.
+>
+> Result 5 (score: 76.22):
+> Similarly, a recent challenge set, the WinoMT dataset (Stanovsky et al., 2019)
+> shows that MT systems perform worse when they are asked to translate sentences
+> that describe people with non-stereotypical gender roles, like "The doctor
+> asked the nurse to help her in the operation".
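The retriever added in this commit follows the dense-retrieval recipe from the HuggingFace FAISS guide that base_model/retriever.py cites: paragraphs and questions are embedded with DPR encoders and matched through a FAISS index. The sketch below condenses that flow into a plain script for orientation only; the dataset name is a placeholder, and the actual implementation is the base_model/retriever.py diff further down.

```python
# Illustrative sketch of the DPR + FAISS flow used by base_model/retriever.py.
# "your/paragraph-dataset" is a placeholder; the real dataset name is passed to Retriever().
import torch
from datasets import load_dataset
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

torch.set_grad_enabled(False)  # inference only, lets us call .numpy() on model outputs

ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")


def embed_paragraph(row):
    # One DPR context embedding per paragraph, stored in an "embeddings" column
    tok = ctx_tokenizer(row["text"], return_tensors="pt", truncation=True)
    return {"embeddings": ctx_encoder(**tok)[0][0].numpy()}


ds = load_dataset("your/paragraph-dataset", name="paragraphs")["train"]
ds = ds.map(embed_paragraph)
ds.add_faiss_index(column="embeddings")

# Embed the question the same way and ask FAISS for the k nearest paragraphs
question = "What is the perplexity of a language model?"
q_emb = q_encoder(**q_tokenizer(question, return_tensors="pt", truncation=True))[0][0].numpy()
scores, results = ds.get_nearest_examples("embeddings", q_emb, k=5)
```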
main.py → base_model/main.py
RENAMED

@@ -1,14 +1,15 @@
-from …
+from retriever import Retriever
+
 
 if __name__ == '__main__':
     # Initialize retriever
     r = Retriever()
 
     # Retrieve example
-    …
-    "…
+    scores, result = r.retrieve(
+        "What is the perplexity of a language model?")
 
-    for i, …
-    print(f"Result {i+1} (score: {score…
-    print(result['text'][…
+    for i, score in enumerate(scores):
+        print(f"Result {i+1} (score: {score:.02f}):")
+        print(result['text'][i])
     print() # Newline
base_model/retriever.py
CHANGED

@@ -1,10 +1,21 @@
-from transformers import …
-…
+from transformers import (
+    DPRContextEncoder,
+    DPRContextEncoderTokenizer,
+    DPRQuestionEncoder,
+    DPRQuestionEncoderTokenizer,
+)
 from datasets import load_dataset
 import torch
+import os.path
 
+# Hacky fix for FAISS error on macOS
+# See https://stackoverflow.com/a/63374568/4545692
+import os
 
+os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
+
+
-class Retriever():
+class Retriever:
     """A class used to retrieve relevant documents based on some query.
     based on https://huggingface.co/docs/datasets/faiss_es#faiss.
     """

@@ -21,47 +32,64 @@ class Retriever():
 
         # Context encoding and tokenization
         self.ctx_encoder = DPRContextEncoder.from_pretrained(
-            "facebook/dpr-ctx_encoder-single-nq-base"…
+            "facebook/dpr-ctx_encoder-single-nq-base"
+        )
         self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
-            "facebook/dpr-ctx_encoder-single-nq-base"…
+            "facebook/dpr-ctx_encoder-single-nq-base"
+        )
 
         # Question encoding and tokenization
         self.q_encoder = DPRQuestionEncoder.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"…
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
         self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"…
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
 
         # Dataset building
         self.dataset = self.__init_dataset(dataset)
 
-    def __init_dataset(self, …
+    def __init_dataset(self,
+                       dataset: str,
+                       fname: str = "./models/paragraphs_embedding.faiss"):
         """Loads the dataset and adds FAISS embeddings.
 
         Args:
             dataset (str): A HuggingFace dataset name.
+            fname (str): The name to use to save the embeddings to disk for
+                faster loading after the first run.
 
         Returns:
             Dataset: A dataset with a new column 'embeddings' containing FAISS
            embeddings.
         """
-        # TODO: save ds w/ embeddings to disk and retrieve it if it already exists
-
         # Load dataset
-        ds = load_dataset(dataset, name=…
+        ds = load_dataset(dataset, name="paragraphs")["train"]
 
+        if os.path.exists(fname):
+            # If we already have FAISS embeddings, load them from disk
+            ds.load_faiss_index('embeddings', fname)
+            return ds
+        else:
+            # If there are no FAISS embeddings, generate them
+            def embed(row):
+                # Inline helper function to perform embedding
+                p = row["text"]
+                tok = self.ctx_tokenizer(
+                    p, return_tensors="pt", truncation=True)
+                enc = self.ctx_encoder(**tok)[0][0].numpy()
+                return {"embeddings": enc}
+
+            # Add FAISS embeddings
+            ds_with_embeddings = ds.map(embed)
+
-            ds_with_embeddings.add_faiss_index(column='embeddings')
-            return ds_with_embeddings
+            ds_with_embeddings.add_faiss_index(column="embeddings")
 
+            # save dataset w/ embeddings
+            os.makedirs("./models/", exist_ok=True)
+            ds_with_embeddings.save_faiss_index("embeddings", fname)
 
+            return ds_with_embeddings
 
     def retrieve(self, query: str, k: int = 5):
         """Retrieve the top k matches for a search query.

@@ -77,10 +105,11 @@
 
         def embed(q):
             # Inline helper function to perform embedding
-            tok = self.q_tokenizer(q, return_tensors=…
+            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
             return self.q_encoder(**tok)[0][0].numpy()
 
         question_embedding = embed(query)
         scores, results = self.dataset.get_nearest_examples(
-            …
+            "embeddings", question_embedding, k=k
+        )
         return scores, results
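Compared to the removed TODO, `__init_dataset` now caches the FAISS index: the first `Retriever()` construction embeds every paragraph and writes ./models/paragraphs_embedding.faiss, and later runs just load that file. A small usage sketch, assuming it is run from base_model/ inside the poetry environment (the `k` argument is the one already defined on `retrieve`):

```python
from retriever import Retriever

# First construction embeds the paragraphs and saves the FAISS index;
# subsequent runs load ./models/paragraphs_embedding.faiss from disk.
r = Retriever()

scores, results = r.retrieve("What is the perplexity of a language model?", k=3)

# `scores` is a NumPy array of FAISS scores and `results` a dict of columns,
# so the i-th score belongs to the i-th entry of results["text"].
for score, text in zip(scores, results["text"]):
    print(f"{score:.2f}  {text[:80]}")
```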
poetry.lock
CHANGED

@@ -51,6 +51,18 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
 tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
 tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
 
+[[package]]
+name = "autopep8"
+version = "1.6.0"
+description = "A tool that automatically formats Python code to conform to the PEP 8 style guide"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+pycodestyle = ">=2.8.0"
+toml = "*"
+
 [[package]]
 name = "certifi"
 version = "2021.10.8"

@@ -460,6 +472,14 @@ python-versions = "*"
 docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
 testing = ["pytest", "requests", "numpy", "datasets"]
 
+[[package]]
+name = "toml"
+version = "0.10.2"
+description = "Python Library for Tom's Obvious, Minimal Language"
+category = "dev"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
+
 [[package]]
 name = "torch"
 version = "1.11.0"

@@ -590,7 +610,7 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "…
+content-hash = "227b922ee14abf36ca75bb238d239d712bed9213d54c567996566d465e465733"
 
 [metadata.files]
 aiohttp = [

@@ -679,6 +699,10 @@ attrs = [
     {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
     {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
 ]
+autopep8 = [
+    {file = "autopep8-1.6.0-py2.py3-none-any.whl", hash = "sha256:ed77137193bbac52d029a52c59bec1b0629b5a186c495f1eb21b126ac466083f"},
+    {file = "autopep8-1.6.0.tar.gz", hash = "sha256:44f0932855039d2c15c4510d6df665e4730f2b8582704fa48f9c55bd3e17d979"},
+]
 certifi = [
     {file = "certifi-2021.10.8-py2.py3-none-any.whl", hash = "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569"},
     {file = "certifi-2021.10.8.tar.gz", hash = "sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872"},

@@ -1161,6 +1185,10 @@ tokenizers = [
     {file = "tokenizers-0.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b28966c68a2cdecd5120f4becea159eebe0335b8202e21e292eb381031026edc"},
     {file = "tokenizers-0.11.6.tar.gz", hash = "sha256:562b2022faf0882586c915385620d1f11798fc1b32bac55353a530132369a6d0"},
 ]
+toml = [
+    {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
+    {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
+]
 torch = [
     {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
     {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
pyproject.toml
CHANGED

@@ -14,6 +14,7 @@ faiss-cpu = "^1.7.2"
 
 [tool.poetry.dev-dependencies]
 flake8 = "^4.0.1"
+autopep8 = "^1.6.0"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]