GGroenendaal commited on
Commit
fa8dc75
β€’
2 Parent(s): 7570c1d 51a31d4

Merge branch 'esretriever' into main

Browse files
.env.example ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ELASTIC_USERNAME=elastic
2
+ ELASTIC_PASSWORD=<password>
3
+
4
+ LOG_LEVEL=INFO
README.md CHANGED
@@ -73,3 +73,6 @@ poetry run python main.py
73
  > shows that MT systems perform worse when they are asked to translate sentences
74
  > that describe people with non-stereotypical gender roles, like "The doctor
75
  > asked the nurse to help her in the > operation".
 
 
 
 
73
  > shows that MT systems perform worse when they are asked to translate sentences
74
  > that describe people with non-stereotypical gender roles, like "The doctor
75
  > asked the nurse to help her in the > operation".
76
+
77
+
78
+ ## Setting up Elasticsearch
base_model/main.py DELETED
@@ -1,20 +0,0 @@
1
- from retriever import Retriever
2
-
3
-
4
- if __name__ == '__main__':
5
- # Initialize retriever
6
- r = Retriever()
7
-
8
- # Retrieve example
9
- scores, result = r.retrieve(
10
- "What is the perplexity of a language model?")
11
-
12
- for i, score in enumerate(scores):
13
- print(f"Result {i+1} (score: {score:.02f}):")
14
- print(result['text'][i])
15
- print() # Newline
16
-
17
- # Compute overall performance
18
- exact_match, f1_score = r.evaluate()
19
- print(f"Exact match: {exact_match:.02f}\n"
20
- f"F1-score: {f1_score:.02f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import cast

from datasets import DatasetDict, load_dataset

from src.evaluation import evaluate
from src.retrievers.fais_retriever import FAISRetriever
from src.utils.log import get_logger

logger = get_logger()


if __name__ == '__main__':
    # NOTE(review): the "paragraphs" config was previously loaded here but
    # never used (FAISRetriever loads it itself), so the redundant download
    # has been dropped.
    dataset_name = "GroNLP/ik-nlp-22_slp"
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    questions_test = questions["test"]

    logger.info(questions)

    # Initialize retriever
    r = FAISRetriever()

    # Retrieve a single example to sanity-check the retriever
    example_q = "What is the perplexity of a language model?"
    scores, result = r.retrieve(example_q)

    logger.info(
        f"Example q: {example_q} answer: {result['text'][0]}")

    for i, score in enumerate(scores):
        logger.info(f"Result {i+1} (score: {score:.02f}):")
        logger.info(result['text'][i])

    # Compute overall performance on the test split
    exact_match, f1_score = evaluate(
        r, questions_test["question"], questions_test["answer"])
    logger.info(f"Exact match: {exact_match:.02f}\n"
                f"F1-score: {f1_score:.02f}")
poetry.lock CHANGED
@@ -149,6 +149,36 @@ python-versions = ">=2.7, !=3.0.*"
149
  [package.extras]
150
  graph = ["objgraph (>=1.7.2)"]
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  [[package]]
153
  name = "faiss-cpu"
154
  version = "1.7.2"
@@ -291,6 +321,32 @@ python-versions = "*"
291
  [package.dependencies]
292
  dill = ">=0.3.4"
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  [[package]]
295
  name = "numpy"
296
  version = "1.22.3"
@@ -380,6 +436,17 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
380
  [package.dependencies]
381
  six = ">=1.5"
382
 
 
 
 
 
 
 
 
 
 
 
 
383
  [[package]]
384
  name = "pytz"
385
  version = "2021.3"
@@ -480,6 +547,14 @@ category = "dev"
480
  optional = false
481
  python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
482
 
 
 
 
 
 
 
 
 
483
  [[package]]
484
  name = "torch"
485
  version = "1.11.0"
@@ -610,7 +685,7 @@ multidict = ">=4.0"
610
  [metadata]
611
  lock-version = "1.1"
612
  python-versions = "^3.8"
613
- content-hash = "227b922ee14abf36ca75bb238d239d712bed9213d54c567996566d465e465733"
614
 
615
  [metadata.files]
616
  aiohttp = [
@@ -727,6 +802,14 @@ dill = [
727
  {file = "dill-0.3.4-py2.py3-none-any.whl", hash = "sha256:7e40e4a70304fd9ceab3535d36e58791d9c4a776b38ec7f7ec9afc8d3dca4d4f"},
728
  {file = "dill-0.3.4.zip", hash = "sha256:9f9734205146b2b353ab3fec9af0070237b6ddae78452af83d2fca84d739e675"},
729
  ]
 
 
 
 
 
 
 
 
730
  faiss-cpu = [
731
  {file = "faiss-cpu-1.7.2.tar.gz", hash = "sha256:f7ea89de997f55764e3710afaf0a457b2529252f99ee63510d4d9348d5b419dd"},
732
  {file = "faiss_cpu-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b7461f989d757917a3e6dc81eb171d0b563eb98d23ebaf7fc6684d0093ba267e"},
@@ -918,12 +1001,40 @@ multiprocess = [
918
  {file = "multiprocess-0.70.12.2-py39-none-any.whl", hash = "sha256:6f812a1d3f198b7cacd63983f60e2dc1338bd4450893f90c435067b5a3127e6f"},
919
  {file = "multiprocess-0.70.12.2.zip", hash = "sha256:206bb9b97b73f87fec1ed15a19f8762950256aa84225450abc7150d02855a083"},
920
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
921
  numpy = [
922
  {file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
923
  {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
924
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
925
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
926
- {file = "numpy-1.22.3-cp310-cp310-win32.whl", hash = "sha256:f950f8845b480cffe522913d35567e29dd381b0dc7e4ce6a4a9f9156417d2430"},
927
  {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
928
  {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
929
  {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
@@ -1015,6 +1126,10 @@ python-dateutil = [
1015
  {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
1016
  {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
1017
  ]
 
 
 
 
1018
  pytz = [
1019
  {file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
1020
  {file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"},
@@ -1189,6 +1304,10 @@ toml = [
1189
  {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
1190
  {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
1191
  ]
 
 
 
 
1192
  torch = [
1193
  {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
1194
  {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
 
149
  [package.extras]
150
  graph = ["objgraph (>=1.7.2)"]
151
 
152
+ [[package]]
153
+ name = "elastic-transport"
154
+ version = "8.1.0"
155
+ description = "Transport classes and utilities shared among Python Elastic client libraries"
156
+ category = "main"
157
+ optional = false
158
+ python-versions = ">=3.6"
159
+
160
+ [package.dependencies]
161
+ certifi = "*"
162
+ urllib3 = ">=1.26.2,<2"
163
+
164
+ [package.extras]
165
+ develop = ["pytest", "pytest-cov", "pytest-mock", "pytest-asyncio", "mock", "requests", "aiohttp"]
166
+
167
+ [[package]]
168
+ name = "elasticsearch"
169
+ version = "8.1.0"
170
+ description = "Python client for Elasticsearch"
171
+ category = "main"
172
+ optional = false
173
+ python-versions = ">=3.6, <4"
174
+
175
+ [package.dependencies]
176
+ elastic-transport = ">=8,<9"
177
+
178
+ [package.extras]
179
+ async = ["aiohttp (>=3,<4)"]
180
+ requests = ["requests (>=2.4.0,<3.0.0)"]
181
+
182
  [[package]]
183
  name = "faiss-cpu"
184
  version = "1.7.2"
 
321
  [package.dependencies]
322
  dill = ">=0.3.4"
323
 
324
+ [[package]]
325
+ name = "mypy"
326
+ version = "0.941"
327
+ description = "Optional static typing for Python"
328
+ category = "dev"
329
+ optional = false
330
+ python-versions = ">=3.6"
331
+
332
+ [package.dependencies]
333
+ mypy-extensions = ">=0.4.3"
334
+ tomli = ">=1.1.0"
335
+ typing-extensions = ">=3.10"
336
+
337
+ [package.extras]
338
+ dmypy = ["psutil (>=4.0)"]
339
+ python2 = ["typed-ast (>=1.4.0,<2)"]
340
+ reports = ["lxml"]
341
+
342
+ [[package]]
343
+ name = "mypy-extensions"
344
+ version = "0.4.3"
345
+ description = "Experimental type system extensions for programs checked with the mypy typechecker."
346
+ category = "dev"
347
+ optional = false
348
+ python-versions = "*"
349
+
350
  [[package]]
351
  name = "numpy"
352
  version = "1.22.3"
 
436
  [package.dependencies]
437
  six = ">=1.5"
438
 
439
+ [[package]]
440
+ name = "python-dotenv"
441
+ version = "0.19.2"
442
+ description = "Read key-value pairs from a .env file and set them as environment variables"
443
+ category = "main"
444
+ optional = false
445
+ python-versions = ">=3.5"
446
+
447
+ [package.extras]
448
+ cli = ["click (>=5.0)"]
449
+
450
  [[package]]
451
  name = "pytz"
452
  version = "2021.3"
 
547
  optional = false
548
  python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
549
 
550
+ [[package]]
551
+ name = "tomli"
552
+ version = "2.0.1"
553
+ description = "A lil' TOML parser"
554
+ category = "dev"
555
+ optional = false
556
+ python-versions = ">=3.7"
557
+
558
  [[package]]
559
  name = "torch"
560
  version = "1.11.0"
 
685
  [metadata]
686
  lock-version = "1.1"
687
  python-versions = "^3.8"
688
+ content-hash = "7fadbb5aabac268ecd27c257e2c8f651d26896e78c9cc0ea7e61a8b6ec61c84c"
689
 
690
  [metadata.files]
691
  aiohttp = [
 
802
  {file = "dill-0.3.4-py2.py3-none-any.whl", hash = "sha256:7e40e4a70304fd9ceab3535d36e58791d9c4a776b38ec7f7ec9afc8d3dca4d4f"},
803
  {file = "dill-0.3.4.zip", hash = "sha256:9f9734205146b2b353ab3fec9af0070237b6ddae78452af83d2fca84d739e675"},
804
  ]
805
+ elastic-transport = [
806
+ {file = "elastic-transport-8.1.0.tar.gz", hash = "sha256:769ee4c7b28d270cdbce71359973b88129ac312b13be95b4f7479e35c49d9455"},
807
+ {file = "elastic_transport-8.1.0-py3-none-any.whl", hash = "sha256:0bb2ae3d13348e9e4587ca1f17cd813a528a7cc07f879505f56d69c81823b660"},
808
+ ]
809
+ elasticsearch = [
810
+ {file = "elasticsearch-8.1.0-py3-none-any.whl", hash = "sha256:11e36565dfdf649b7911c2d3cb1f15b99267acfb7f82e94e7613c0323a9936e9"},
811
+ {file = "elasticsearch-8.1.0.tar.gz", hash = "sha256:648d1c707a632279535356d2762cbc63ae728c4633211fe160f43f87a3e1cdcd"},
812
+ ]
813
  faiss-cpu = [
814
  {file = "faiss-cpu-1.7.2.tar.gz", hash = "sha256:f7ea89de997f55764e3710afaf0a457b2529252f99ee63510d4d9348d5b419dd"},
815
  {file = "faiss_cpu-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b7461f989d757917a3e6dc81eb171d0b563eb98d23ebaf7fc6684d0093ba267e"},
 
1001
  {file = "multiprocess-0.70.12.2-py39-none-any.whl", hash = "sha256:6f812a1d3f198b7cacd63983f60e2dc1338bd4450893f90c435067b5a3127e6f"},
1002
  {file = "multiprocess-0.70.12.2.zip", hash = "sha256:206bb9b97b73f87fec1ed15a19f8762950256aa84225450abc7150d02855a083"},
1003
  ]
1004
+ mypy = [
1005
+ {file = "mypy-0.941-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:98f61aad0bb54f797b17da5b82f419e6ce214de0aa7e92211ebee9e40eb04276"},
1006
+ {file = "mypy-0.941-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6a8e1f63357851444940351e98fb3252956a15f2cabe3d698316d7a2d1f1f896"},
1007
+ {file = "mypy-0.941-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b30d29251dff4c59b2e5a1fa1bab91ff3e117b4658cb90f76d97702b7a2ae699"},
1008
+ {file = "mypy-0.941-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8eaf55fdf99242a1c8c792247c455565447353914023878beadb79600aac4a2a"},
1009
+ {file = "mypy-0.941-cp310-cp310-win_amd64.whl", hash = "sha256:080097eee5393fd740f32c63f9343580aaa0fb1cda0128fd859dfcf081321c3d"},
1010
+ {file = "mypy-0.941-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f79137d012ff3227866222049af534f25354c07a0d6b9a171dba9f1d6a1fdef4"},
1011
+ {file = "mypy-0.941-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8e5974583a77d630a5868eee18f85ac3093caf76e018c510aeb802b9973304ce"},
1012
+ {file = "mypy-0.941-cp36-cp36m-win_amd64.whl", hash = "sha256:0dd441fbacf48e19dc0c5c42fafa72b8e1a0ba0a39309c1af9c84b9397d9b15a"},
1013
+ {file = "mypy-0.941-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0d3bcbe146247997e03bf030122000998b076b3ac6925b0b6563f46d1ce39b50"},
1014
+ {file = "mypy-0.941-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3bada0cf7b6965627954b3a128903a87cac79a79ccd83b6104912e723ef16c7b"},
1015
+ {file = "mypy-0.941-cp37-cp37m-win_amd64.whl", hash = "sha256:eea10982b798ff0ccc3b9e7e42628f932f552c5845066970e67cd6858655d52c"},
1016
+ {file = "mypy-0.941-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:108f3c7e14a038cf097d2444fa0155462362c6316e3ecb2d70f6dd99cd36084d"},
1017
+ {file = "mypy-0.941-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d61b73c01fc1de799226963f2639af831307fe1556b04b7c25e2b6c267a3bc76"},
1018
+ {file = "mypy-0.941-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:42c216a33d2bdba08098acaf5bae65b0c8196afeb535ef4b870919a788a27259"},
1019
+ {file = "mypy-0.941-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fc5ecff5a3bbfbe20091b1cad82815507f5ae9c380a3a9bf40f740c70ce30a9b"},
1020
+ {file = "mypy-0.941-cp38-cp38-win_amd64.whl", hash = "sha256:bf446223b2e0e4f0a4792938e8d885e8a896834aded5f51be5c3c69566495540"},
1021
+ {file = "mypy-0.941-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:745071762f32f65e77de6df699366d707fad6c132a660d1342077cbf671ef589"},
1022
+ {file = "mypy-0.941-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:465a6ce9ca6268cadfbc27a2a94ddf0412568a6b27640ced229270be4f5d394d"},
1023
+ {file = "mypy-0.941-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d051ce0946521eba48e19b25f27f98e5ce4dbc91fff296de76240c46b4464df0"},
1024
+ {file = "mypy-0.941-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:818cfc51c25a5dbfd0705f3ac1919fff6971eb0c02e6f1a1f6a017a42405a7c0"},
1025
+ {file = "mypy-0.941-cp39-cp39-win_amd64.whl", hash = "sha256:b2ce2788df0c066c2ff4ba7190fa84f18937527c477247e926abeb9b1168b8cc"},
1026
+ {file = "mypy-0.941-py3-none-any.whl", hash = "sha256:3cf77f138efb31727ee7197bc824c9d6d7039204ed96756cc0f9ca7d8e8fc2a4"},
1027
+ {file = "mypy-0.941.tar.gz", hash = "sha256:cbcc691d8b507d54cb2b8521f0a2a3d4daa477f62fe77f0abba41e5febb377b7"},
1028
+ ]
1029
+ mypy-extensions = [
1030
+ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
1031
+ {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
1032
+ ]
1033
  numpy = [
1034
  {file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
1035
  {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
1036
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
1037
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
 
1038
  {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
1039
  {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
1040
  {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
 
1126
  {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
1127
  {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
1128
  ]
1129
+ python-dotenv = [
1130
+ {file = "python-dotenv-0.19.2.tar.gz", hash = "sha256:a5de49a31e953b45ff2d2fd434bbc2670e8db5273606c1e737cc6b93eff3655f"},
1131
+ {file = "python_dotenv-0.19.2-py2.py3-none-any.whl", hash = "sha256:32b2bdc1873fd3a3c346da1c6db83d0053c3c62f28f1f38516070c4c8971b1d3"},
1132
+ ]
1133
  pytz = [
1134
  {file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
1135
  {file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"},
 
1304
  {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
1305
  {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
1306
  ]
1307
+ tomli = [
1308
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
1309
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
1310
+ ]
1311
  torch = [
1312
  {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
1313
  {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
pyproject.toml CHANGED
@@ -11,10 +11,28 @@ transformers = "^4.17.0"
11
  torch = "^1.11.0"
12
  datasets = "^1.18.4"
13
  faiss-cpu = "^1.7.2"
 
 
14
 
15
  [tool.poetry.dev-dependencies]
16
  flake8 = "^4.0.1"
17
  autopep8 = "^1.6.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  [build-system]
20
  requires = ["poetry-core>=1.0.0"]
 
11
  torch = "^1.11.0"
12
  datasets = "^1.18.4"
13
  faiss-cpu = "^1.7.2"
14
+ python-dotenv = "^0.19.2"
15
+ elasticsearch = "^8.1.0"
16
 
17
  [tool.poetry.dev-dependencies]
18
  flake8 = "^4.0.1"
19
  autopep8 = "^1.6.0"
20
+ mypy = "^0.941"
21
+
22
+ [tool.mypy]
23
+ no_implicit_optional=true
24
+
25
+ [[tool.mypy.overrides]]
26
+ module = [
27
+ "transformers",
28
+ "datasets",
29
+ ]
30
+ ignore_missing_imports = true
31
+
32
+
33
+ [tool.isort]
34
+ profile = "black"
35
+
36
 
37
  [build-system]
38
  requires = ["poetry-core>=1.0.0"]
base_model/evaluate.py β†’ src/evaluation.py RENAMED
@@ -1,15 +1,17 @@
1
- from typing import Callable, List
 
2
 
3
- from base_model.string_utils import lower, remove_articles, remove_punc, white_space_fix
 
4
 
5
 
6
- def normalize_text(inp: str, preprocessing_functions: List[Callable[[str], str]]):
7
  for fun in preprocessing_functions:
8
  inp = fun(inp)
9
  return inp
10
 
11
 
12
- def normalize_text_default(inp: str) -> str:
13
  """Preprocesses the sentence string by normalizing.
14
 
15
  Args:
@@ -21,10 +23,10 @@ def normalize_text_default(inp: str) -> str:
21
 
22
  steps = [remove_articles, white_space_fix, remove_punc, lower]
23
 
24
- return normalize_text(inp, steps)
25
 
26
 
27
- def compute_exact_match(prediction: str, answer: str) -> int:
28
  """Computes exact match for sentences.
29
 
30
  Args:
@@ -34,10 +36,10 @@ def compute_exact_match(prediction: str, answer: str) -> int:
34
  Returns:
35
  int: 1 for exact match, 0 for not
36
  """
37
- return int(normalize_text_default(prediction) == normalize_text_default(answer))
38
 
39
 
40
- def compute_f1(prediction: str, answer: str) -> float:
41
  """Computes F1-score on token overlap for sentences.
42
 
43
  Args:
@@ -47,8 +49,8 @@ def compute_f1(prediction: str, answer: str) -> float:
47
  Returns:
48
  boolean: the f1 score
49
  """
50
- pred_tokens = normalize_text_default(prediction).split()
51
- answer_tokens = normalize_text_default(answer).split()
52
 
53
  if len(pred_tokens) == 0 or len(answer_tokens) == 0:
54
  return int(pred_tokens == answer_tokens)
@@ -62,3 +64,29 @@ def compute_f1(prediction: str, answer: str) -> float:
62
  rec = len(common_tokens) / len(answer_tokens)
63
 
64
  return 2 * (prec * rec) / (prec + rec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, List
2
+ from src.retrievers.base_retriever import Retriever
3
 
4
+ from src.utils.string_utils import (lower, remove_articles, remove_punc,
5
+ white_space_fix)
6
 
7
 
8
+ def _normalize_text(inp: str, preprocessing_functions: List[Callable[[str], str]]):
9
  for fun in preprocessing_functions:
10
  inp = fun(inp)
11
  return inp
12
 
13
 
14
+ def _normalize_text_default(inp: str) -> str:
15
  """Preprocesses the sentence string by normalizing.
16
 
17
  Args:
 
23
 
24
  steps = [remove_articles, white_space_fix, remove_punc, lower]
25
 
26
+ return _normalize_text(inp, steps)
27
 
28
 
29
+ def exact_match(prediction: str, answer: str) -> int:
30
  """Computes exact match for sentences.
31
 
32
  Args:
 
36
  Returns:
37
  int: 1 for exact match, 0 for not
38
  """
39
+ return int(_normalize_text_default(prediction) == _normalize_text_default(answer))
40
 
41
 
42
+ def f1(prediction: str, answer: str) -> float:
43
  """Computes F1-score on token overlap for sentences.
44
 
45
  Args:
 
49
  Returns:
50
  boolean: the f1 score
51
  """
52
+ pred_tokens = _normalize_text_default(prediction).split()
53
+ answer_tokens = _normalize_text_default(answer).split()
54
 
55
  if len(pred_tokens) == 0 or len(answer_tokens) == 0:
56
  return int(pred_tokens == answer_tokens)
 
64
  rec = len(common_tokens) / len(answer_tokens)
65
 
66
  return 2 * (prec * rec) / (prec + rec)
67
+
68
+
69
def evaluate(retriever: Retriever, questions: Any, answers: Any):
    """Evaluates the entire model by computing F1-score and exact match on the
    entire dataset.

    Args:
        retriever (Retriever): retriever used to produce one prediction per question
        questions (Any): iterable of question strings
        answers (Any): sequence of gold answer strings, parallel to questions

    Returns:
        float: overall exact match
        float: overall F1-score
    """
    predictions = []

    # Currently just takes the first retrieved document as the prediction
    # and does not look at the retrieval scores yet.
    for question in questions:
        _score, result = retriever.retrieve(question, 1)
        predictions.append(result['text'][0])

    # Guard against an empty evaluation set (would otherwise divide by zero).
    if not answers:
        return 0.0, 0.0

    exact_matches = [exact_match(pred, ans)
                     for pred, ans in zip(predictions, answers)]
    f1_scores = [f1(pred, ans)
                 for pred, ans in zip(predictions, answers)]

    return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
{base_model β†’ src}/reader.py RENAMED
File without changes
src/retrievers/base_retriever.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
class Retriever:
    """Common interface for document retrievers.

    Concrete retrievers (e.g. the FAISS- or Elasticsearch-backed ones)
    override retrieve() and return the scores and documents for the
    top-k matches of a query.
    """

    def retrieve(self, query: str, k: int):
        """Retrieve the k most relevant documents for a query.

        Args:
            query (str): the question to search for
            k (int): the number of documents to return

        Raises:
            NotImplementedError: subclasses must override this method;
                the base class deliberately fails loudly instead of
                silently returning None.
        """
        raise NotImplementedError
src/retrievers/es_retriever.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Fix: the base class Retriever was referenced but never imported,
# which made importing this module raise a NameError.
from src.retrievers.base_retriever import Retriever
from src.utils.log import get_logger

logger = get_logger()


class ESRetriever(Retriever):
    """Retriever backed by an Elasticsearch index (stub, not implemented yet)."""

    def __init__(self, data_set):
        # TODO: connect to Elasticsearch (credentials come from the .env
        # file, see .env.example) and index data_set.
        pass

    def retrieve(self, query: str, k: int):
        # TODO: query the Elasticsearch index for the top-k documents.
        pass
base_model/retriever.py β†’ src/retrievers/fais_retriever.py RENAMED
@@ -1,23 +1,27 @@
 
 
 
 
 
1
  from transformers import (
2
  DPRContextEncoder,
3
  DPRContextEncoderTokenizer,
4
  DPRQuestionEncoder,
5
  DPRQuestionEncoderTokenizer,
6
  )
7
- from datasets import load_dataset
8
- import torch
9
- import os.path
10
 
11
- import evaluate
 
12
 
 
13
  # Hacky fix for FAISS error on macOS
14
  # See https://stackoverflow.com/a/63374568/4545692
15
- import os
16
 
17
- os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
 
18
 
19
 
20
- class Retriever:
21
  """A class used to retrieve relevant documents based on some query.
22
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
23
  """
@@ -67,12 +71,13 @@ class Retriever:
67
  embeddings.
68
  """
69
  # Load dataset
70
- ds = load_dataset(dataset_name, name="paragraphs")["train"]
71
- print(ds)
 
72
 
73
  if os.path.exists(embedding_path):
74
  # If we already have FAISS embeddings, load them from disk
75
- ds.load_faiss_index('embeddings', embedding_path)
76
  return ds
77
  else:
78
  # If there are no FAISS embeddings, generate them
@@ -85,7 +90,7 @@ class Retriever:
85
  return {"embeddings": enc}
86
 
87
  # Add FAISS embeddings
88
- ds_with_embeddings = ds.map(embed)
89
 
90
  ds_with_embeddings.add_faiss_index(column="embeddings")
91
 
@@ -118,32 +123,3 @@ class Retriever:
118
  )
119
 
120
  return scores, results
121
-
122
- def evaluate(self):
123
- """Evaluates the entire model by computing F1-score and exact match on the
124
- entire dataset.
125
-
126
- Returns:
127
- float: overall exact match
128
- float: overall F1-score
129
- """
130
- questions_ds = load_dataset(
131
- self.dataset_name, name="questions")['test']
132
- questions = questions_ds['question']
133
- answers = questions_ds['answer']
134
-
135
- predictions = []
136
- scores = 0
137
-
138
- # Currently just takes the first answer and does not look at scores yet
139
- for question in questions:
140
- score, result = self.retrieve(question, 1)
141
- scores += score[0]
142
- predictions.append(result['text'][0])
143
-
144
- exact_matches = [evaluate.compute_exact_match(
145
- predictions[i], answers[i]) for i in range(len(answers))]
146
- f1_scores = [evaluate.compute_f1(
147
- predictions[i], answers[i]) for i in range(len(answers))]
148
-
149
- return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
 
1
+ import os
2
+ import os.path
3
+
4
+ import torch
5
+ from datasets import load_dataset
6
  from transformers import (
7
  DPRContextEncoder,
8
  DPRContextEncoderTokenizer,
9
  DPRQuestionEncoder,
10
  DPRQuestionEncoderTokenizer,
11
  )
 
 
 
12
 
13
+ from src.retrievers.base_retriever import Retriever
14
+ from src.utils.log import get_logger
15
 
16
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
17
  # Hacky fix for FAISS error on macOS
18
  # See https://stackoverflow.com/a/63374568/4545692
 
19
 
20
+
21
+ logger = get_logger()
22
 
23
 
24
+ class FAISRetriever(Retriever):
25
  """A class used to retrieve relevant documents based on some query.
26
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
27
  """
 
71
  embeddings.
72
  """
73
  # Load dataset
74
+ ds = load_dataset(dataset_name, name="paragraphs")[
75
+ "train"] # type: ignore
76
+ logger.info(ds)
77
 
78
  if os.path.exists(embedding_path):
79
  # If we already have FAISS embeddings, load them from disk
80
+ ds.load_faiss_index('embeddings', embedding_path) # type: ignore
81
  return ds
82
  else:
83
  # If there are no FAISS embeddings, generate them
 
90
  return {"embeddings": enc}
91
 
92
  # Add FAISS embeddings
93
+ ds_with_embeddings = ds.map(embed) # type: ignore
94
 
95
  ds_with_embeddings.add_faiss_index(column="embeddings")
96
 
 
123
  )
124
 
125
  return scores, results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/log.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+
9
+ def get_logger():
10
+ # creates a default logger for the project
11
+ logger = logging.getLogger("Flashcards")
12
+
13
+ log_level = os.getenv("LOG_LEVEL", "INFO")
14
+ logger.setLevel(log_level)
15
+
16
+ # Log format
17
+ formatter = logging.Formatter(
18
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
19
+
20
+ # file handler
21
+ fh = logging.FileHandler("logs.log")
22
+ fh.setFormatter(formatter)
23
+
24
+ # stout
25
+ ch = logging.StreamHandler()
26
+ ch.setFormatter(formatter)
27
+
28
+ logger.addHandler(fh)
29
+ logger.addHandler(ch)
30
+
31
+ return logger
{base_model β†’ src/utils}/string_utils.py RENAMED
File without changes