Robert committed
Commit 7570c1d
2 Parent(s): 8fe5a80 aa426fb

Merge pull request #1

base_model/evaluate.py CHANGED
@@ -1,29 +1,27 @@
-def normalize_text(s: str) -> str:
+from typing import Callable, List
+
+from base_model.string_utils import lower, remove_articles, remove_punc, white_space_fix
+
+
+def normalize_text(inp: str, preprocessing_functions: List[Callable[[str], str]]) -> str:
+    for fun in preprocessing_functions:
+        inp = fun(inp)
+    return inp
+
+
+def normalize_text_default(inp: str) -> str:
     """Preprocesses the sentence string by normalizing.
 
     Args:
         s (str): the sentence
 
     Returns:
-        string: normalized sentence
+        string: normalized with default params
     """
-    import string, re
-
-    def remove_articles(text):
-        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
-        return re.sub(regex, " ", text)
-
-    def white_space_fix(text):
-        return " ".join(text.split())
-
-    def remove_punc(text):
-        exclude = set(string.punctuation)
-        return "".join(ch for ch in text if ch not in exclude)
-
-    def lower(text):
-        return text.lower()
-
-    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+    steps = [remove_articles, white_space_fix, remove_punc, lower]
+
+    return normalize_text(inp, steps)
 
 
 def compute_exact_match(prediction: str, answer: str) -> int:
@@ -36,7 +34,7 @@ def compute_exact_match(prediction: str, answer: str) -> int:
     Returns:
         int: 1 for exact match, 0 for not
     """
-    return int(normalize_text(prediction) == normalize_text(answer))
+    return int(normalize_text_default(prediction) == normalize_text_default(answer))
 
 
 def compute_f1(prediction: str, answer: str) -> float:
@@ -49,8 +47,8 @@ def compute_f1(prediction: str, answer: str) -> float:
     Returns:
         boolean: the f1 score
     """
-    pred_tokens = normalize_text(prediction).split()
-    answer_tokens = normalize_text(answer).split()
+    pred_tokens = normalize_text_default(prediction).split()
+    answer_tokens = normalize_text_default(answer).split()
 
     if len(pred_tokens) == 0 or len(answer_tokens) == 0:
         return int(pred_tokens == answer_tokens)
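For reference, a minimal usage sketch of the refactored normalization pipeline (run from the repo root so that base_model is importable; the example strings are illustrative). Note that normalize_text_default applies lower last, so article removal is case-sensitive here, unlike the old lower-first composition:

from base_model.evaluate import normalize_text, normalize_text_default
from base_model.string_utils import lower, white_space_fix

# Default pipeline: remove_articles -> white_space_fix -> remove_punc -> lower
print(normalize_text_default("the quick,  brown fox!"))  # -> "quick brown fox"

# Custom pipeline: any ordered list of str -> str callables
print(normalize_text("  Hello   World  ", [lower, white_space_fix]))  # -> "hello world"
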
base_model/retriever.py CHANGED
@@ -22,7 +22,7 @@ class Retriever:
     based on https://huggingface.co/docs/datasets/faiss_es#faiss.
     """
 
-    def __init__(self, dataset: str = "GroNLP/ik-nlp-22_slp") -> None:
+    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
         """Initialize the retriever
 
         Args:
@@ -49,12 +49,12 @@ class Retriever:
         )
 
         # Dataset building
-        self.dataset = self.__init_dataset(dataset)
+        self.dataset_name = dataset_name
+        self.dataset = self._init_dataset(dataset_name)
 
-
-    def __init_dataset(self,
-                       dataset: str,
-                       fname: str = "./models/paragraphs_embedding.faiss"):
+    def _init_dataset(self,
+                      dataset_name: str,
+                      embedding_path: str = "./models/paragraphs_embedding.faiss"):
         """Loads the dataset and adds FAISS embeddings.
 
         Args:
@@ -67,12 +67,12 @@ class Retriever:
            embeddings.
         """
         # Load dataset
-        ds = load_dataset(dataset, name="paragraphs")["train"]
+        ds = load_dataset(dataset_name, name="paragraphs")["train"]
         print(ds)
 
-        if os.path.exists(fname):
+        if os.path.exists(embedding_path):
             # If we already have FAISS embeddings, load them from disk
-            ds.load_faiss_index('embeddings', fname)
+            ds.load_faiss_index('embeddings', embedding_path)
             return ds
         else:
             # If there are no FAISS embeddings, generate them
@@ -91,7 +91,7 @@ class Retriever:
 
         # save dataset w/ embeddings
         os.makedirs("./models/", exist_ok=True)
-        ds_with_embeddings.save_faiss_index("embeddings", fname)
+        ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
 
         return ds_with_embeddings
 
@@ -127,7 +127,8 @@ class Retriever:
             float: overall exact match
             float: overall F1-score
         """
-        questions_ds = load_dataset("GroNLP/ik-nlp-22_slp", name="questions")['test']
+        questions_ds = load_dataset(
+            self.dataset_name, name="questions")['test']
         questions = questions_ds['question']
         answers = questions_ds['answer']
 
@@ -140,7 +141,9 @@ class Retriever:
            scores += score[0]
            predictions.append(result['text'][0])
 
-        exact_matches = [evaluate.compute_exact_match(predictions[i], answers[i]) for i in range(len(answers))]
-        f1_scores = [evaluate.compute_f1(predictions[i], answers[i]) for i in range(len(answers))]
+        exact_matches = [evaluate.compute_exact_match(
+            predictions[i], answers[i]) for i in range(len(answers))]
+        f1_scores = [evaluate.compute_f1(
+            predictions[i], answers[i]) for i in range(len(answers))]
 
         return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
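The cache-or-build logic in _init_dataset follows a reusable pattern; a minimal standalone sketch, assuming a hypothetical embed_fn and a "text" column (the actual embedding code sits outside this diff):

import os
from datasets import load_dataset

def init_dataset_sketch(dataset_name: str = "GroNLP/ik-nlp-22_slp",
                        embedding_path: str = "./models/paragraphs_embedding.faiss",
                        embed_fn=None):
    # Load the paragraphs split, as Retriever._init_dataset does
    ds = load_dataset(dataset_name, name="paragraphs")["train"]
    if os.path.exists(embedding_path):
        # A FAISS index is cached on disk: reuse it instead of re-embedding
        ds.load_faiss_index("embeddings", embedding_path)
        return ds
    # No cache yet: embed every row (embed_fn is a placeholder), build the
    # index, and save it so later runs hit the branch above
    ds = ds.map(lambda row: {"embeddings": embed_fn(row["text"])})
    ds.add_faiss_index(column="embeddings")
    os.makedirs("./models/", exist_ok=True)
    ds.save_faiss_index("embeddings", embedding_path)
    return ds
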
base_model/string_utils.py ADDED
@@ -0,0 +1,20 @@
+import re
+import string
+
+
+def remove_articles(text):
+    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+    return re.sub(regex, " ", text)
+
+
+def white_space_fix(text):
+    return " ".join(text.split())
+
+
+def remove_punc(text):
+    exclude = set(string.punctuation)
+    return "".join(ch for ch in text if ch not in exclude)
+
+
+def lower(text):
+    return text.lower()
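Each helper is a pure str -> str function, so callers can compose them in any order; a quick check (example string illustrative, order mirrors the old lower-first composition):

from base_model.string_utils import lower, remove_articles, remove_punc, white_space_fix

s = "An Apple, a Day!"
for step in (lower, remove_punc, remove_articles, white_space_fix):
    s = step(s)
print(s)  # -> "apple day"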