bowdbeg committed on
Commit ba888e1
1 Parent(s): fda183d
Files changed (3)
  1. docred.py +155 -13
  2. official.py +171 -0
  3. sample.py +10 -0
docred.py CHANGED
@@ -13,9 +13,10 @@
 # limitations under the License.
 """TODO: Add a description here."""
 
-import evaluate
-import datasets
-
+import os
+
+import datasets
+import evaluate
 
 # TODO: Add BibTeX citation
 _CITATION = """\
@@ -61,7 +62,30 @@ BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
 class docred(evaluate.Metric):
     """TODO: Short description of my evaluation module."""
 
+    dataset_feat = {
+        "title": datasets.Value("string"),
+        "sents": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
+        "vertexSet": datasets.Sequence(
+            datasets.Sequence(
+                {
+                    "name": datasets.Value("string"),
+                    "sent_id": datasets.Value("int32"),
+                    "pos": datasets.Sequence(datasets.Value("int32"), length=2),
+                    "type": datasets.Value("string"),
+                }
+            )
+        ),
+        "labels": {
+            "head": datasets.Sequence(datasets.Value("int32")),
+            "tail": datasets.Sequence(datasets.Value("int32")),
+            "relation_id": datasets.Sequence(datasets.Value("string")),
+            "relation_text": datasets.Sequence(datasets.Value("string")),
+            "evidence": datasets.Sequence(datasets.Sequence(datasets.Value("int32"))),
+        },
+    }
+
     def _info(self):
+
         # TODO: Specifies the evaluate.EvaluationModuleInfo object
         return evaluate.MetricInfo(
             # This is the description that will appear on the modules page.
@@ -70,15 +94,12 @@ class docred(evaluate.Metric):
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
             # This defines the format of each prediction and reference
-            features=datasets.Features({
-                'predictions': datasets.Value('int64'),
-                'references': datasets.Value('int64'),
-            }),
+            features=datasets.Features({"predictions": self.dataset_feat, "references": self.dataset_feat}),
             # Homepage of the module for documentation
             homepage="http://module.homepage",
             # Additional links to the codebase or references
             codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
-            reference_urls=["http://path.to.reference.url/new_module"]
+            reference_urls=["http://path.to.reference.url/new_module"],
         )
 
     def _download_and_prepare(self, dl_manager):
@@ -86,10 +107,131 @@ class docred(evaluate.Metric):
         # TODO: Download external resources if needed
         pass
 
-    def _compute(self, predictions, references):
+    def _generate_fact(self, dataset):
+        if dataset is None:
+            return set()
+        facts = set()
+        for data in dataset:
+            vertexSet = data["vertexSet"]
+            labels = self._convert_labels_to_list(data["labels"])
+            for label in labels:
+                rel = label["relation_id"]
+                for n1 in vertexSet[label["head"]]["name"]:
+                    for n2 in vertexSet[label["tail"]]["name"]:
+                        facts.add((n1, n2, rel))
+        return facts
+
+    def _convert_to_relation_set(self, data):
+        relation_set = set()
+        for d in data:
+            labels = d["labels"]
+            labels = self._convert_labels_to_list(labels)
+            for label in labels:
+                relation_set.add((d["title"], label["head"], label["tail"], label["relation_id"]))
+        return relation_set
+
+    def _convert_labels_to_list(self, labels):
+        keys = list(labels.keys())
+        labels = [{key: labels[key][i] for key in keys} for i in range(len(labels[keys[0]]))]
+        return labels
+
+    def _compute(self, predictions, references, train_data=None):
         """Returns the scores"""
-        # TODO: Compute the different scores of the module
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+
+        fact_in_train_annotated = self._generate_fact(train_data)
+
+        std = {}
+        tot_evidences = 0
+        ref_titleset = set([])
+
+        title2vectexSet = {}
+
+        for x in references:
+            title = x["title"]
+            ref_titleset.add(title)
+
+            vertexSet = x["vertexSet"]
+            title2vectexSet[title] = vertexSet
+            labels = self._convert_labels_to_list(x["labels"])
+            for label in labels:
+                r = label["relation_id"]
+                h_idx = label["head"]
+                t_idx = label["tail"]
+                std[(title, r, h_idx, t_idx)] = set(label["evidence"])
+                tot_evidences += len(label["evidence"])
+
+        tot_relations = len(std)
+        pred_rel = self._convert_to_relation_set(predictions)
+        submission_answer = sorted(pred_rel, key=lambda x: (x[0], x[1], x[2], x[3]))
+
+        correct_re = 0
+        correct_evidence = 0
+        pred_evi = 0
+
+        correct_in_train_annotated = 0
+        titleset2 = set([])
+        for x in submission_answer:
+            title, h_idx, t_idx, r = x
+            titleset2.add(title)
+            if title not in title2vectexSet:
+                continue
+            vertexSet = title2vectexSet[title]
+
+            if "evidence" in x:
+                evi = set(x["evidence"])
+            else:
+                evi = set([])
+            pred_evi += len(evi)
+
+            if (title, r, h_idx, t_idx) in std:
+                correct_re += 1
+                stdevi = std[(title, r, h_idx, t_idx)]
+                correct_evidence += len(stdevi & evi)
+                in_train_annotated = in_train_distant = False
+                for n1 in vertexSet[h_idx]["name"]:
+                    for n2 in vertexSet[t_idx]["name"]:
+                        if (n1, n2, r) in fact_in_train_annotated:
+                            in_train_annotated = True
+
+                if in_train_annotated:
+                    correct_in_train_annotated += 1
+                # if in_train_distant:
+                #     correct_in_train_distant += 1
+
+        re_p = 1.0 * correct_re / (len(submission_answer) + 1e-5)
+        re_r = 1.0 * correct_re / (tot_relations + 1e-5)
+        if re_p + re_r == 0:
+            re_f1 = 0
+        else:
+            re_f1 = 2.0 * re_p * re_r / (re_p + re_r)
+
+        evi_p = 1.0 * correct_evidence / pred_evi if pred_evi > 0 else 0
+        evi_r = 1.0 * correct_evidence / tot_evidences
+        if evi_p + evi_r == 0:
+            evi_f1 = 0
+        else:
+            evi_f1 = 2.0 * evi_p * evi_r / (evi_p + evi_r)
+
+        re_p_ignore_train_annotated = (
+            1.0
+            * (correct_re - correct_in_train_annotated)
+            / (len(submission_answer) - correct_in_train_annotated + 1e-5)
+        )
+        # re_p_ignore_train = (
+        #     1.0 * (correct_re - correct_in_train_distant) / (len(submission_answer) - correct_in_train_distant + 1e-5)
+        # )
+
+        if re_p_ignore_train_annotated + re_r == 0:
+            re_f1_ignore_train_annotated = 0
+        else:
+            re_f1_ignore_train_annotated = (
+                2.0 * re_p_ignore_train_annotated * re_r / (re_p_ignore_train_annotated + re_r)
+            )
+
+        # if re_p_ignore_train + re_r == 0:
+        #     re_f1_ignore_train = 0
+        # else:
+        #     re_f1_ignore_train = 2.0 * re_p_ignore_train * re_r / (re_p_ignore_train + re_r)
+
+        # return re_f1, evi_f1, re_f1_ignore_train_annotated, re_f1_ignore_train, re_p, re_r
+        return {"f1": re_f1, "precision": re_p, "recall": re_r, "ign_f1": re_f1_ignore_train_annotated}
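Not part of the commit, but for orientation: a minimal sketch of how the reworked _compute is meant to be driven, mirroring sample.py below and additionally passing the optional train_data argument so that ign_f1 discounts predicted facts already present in the annotated training split. Split names, slice sizes, and the use of the validation slice as "predictions" are illustrative.

import datasets

from docred import docred

# Illustrative slices; any DocRED-formatted predictions/references work here.
train = datasets.load_dataset("docred", split="train_annotated[:100]")
dev = datasets.load_dataset("docred", split="validation[:100]")

metric = docred()
scores = metric.compute(
    predictions=dev.to_list(),   # a real system would pass its own predictions
    references=dev.to_list(),
    train_data=train.to_list(),  # predictions whose fact appears here are ignored for ign_f1
)
print(scores)  # {"f1": ..., "precision": ..., "recall": ..., "ign_f1": ...}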
official.py ADDED
@@ -0,0 +1,171 @@
+#!/usr/bin/env python
+import json
+import os
+import os.path
+import sys
+
+
+def gen_train_facts(data_file_name, truth_dir):
+    fact_file_name = data_file_name[data_file_name.find("train_") :]
+    fact_file_name = os.path.join(truth_dir, fact_file_name.replace(".json", ".fact"))
+
+    if os.path.exists(fact_file_name):
+        fact_in_train = set([])
+        triples = json.load(open(fact_file_name))
+        for x in triples:
+            fact_in_train.add(tuple(x))
+        return fact_in_train
+
+    fact_in_train = set([])
+    ori_data = json.load(open(data_file_name))
+    for data in ori_data:
+        vertexSet = data["vertexSet"]
+        for label in data["labels"]:
+            rel = label["r"]
+            for n1 in vertexSet[label["h"]]:
+                for n2 in vertexSet[label["t"]]:
+                    fact_in_train.add((n1["name"], n2["name"], rel))
+
+    json.dump(list(fact_in_train), open(fact_file_name, "w"))
+
+    return fact_in_train
+
+
+input_dir = sys.argv[1]
+output_dir = sys.argv[2]
+
+submit_dir = os.path.join(input_dir, "res")
+truth_dir = os.path.join(input_dir, "ref")
+
+if not os.path.isdir(submit_dir):
+    print("%s doesn't exist" % submit_dir)
+
+if os.path.isdir(submit_dir) and os.path.isdir(truth_dir):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    fact_in_train_annotated = gen_train_facts("../data/train_annotated.json", truth_dir)
+    fact_in_train_distant = gen_train_facts("../data/train_distant.json", truth_dir)
+
+    output_filename = os.path.join(output_dir, "scores.txt")
+    output_file = open(output_filename, "w")
+
+    truth_file = os.path.join(truth_dir, "dev_test.json")
+    truth = json.load(open(truth_file))
+
+    std = {}
+    tot_evidences = 0
+    titleset = set([])
+
+    title2vectexSet = {}
+
+    for x in truth:
+        title = x["title"]
+        titleset.add(title)
+
+        vertexSet = x["vertexSet"]
+        title2vectexSet[title] = vertexSet
+
+        for label in x["labels"]:
+            r = label["r"]
+
+            h_idx = label["h"]
+            t_idx = label["t"]
+            std[(title, r, h_idx, t_idx)] = set(label["evidence"])
+            tot_evidences += len(label["evidence"])
+
+    tot_relations = len(std)
+
+    submission_answer_file = os.path.join(submit_dir, "result.json")
+    tmp = json.load(open(submission_answer_file))
+    tmp.sort(key=lambda x: (x["title"], x["h_idx"], x["t_idx"], x["r"]))
+    submission_answer = [tmp[0]]
+    for i in range(1, len(tmp)):
+        x = tmp[i]
+        y = tmp[i - 1]
+        if (x["title"], x["h_idx"], x["t_idx"], x["r"]) != (y["title"], y["h_idx"], y["t_idx"], y["r"]):
+            submission_answer.append(tmp[i])
+
+    correct_re = 0
+    correct_evidence = 0
+    pred_evi = 0
+
+    correct_in_train_annotated = 0
+    correct_in_train_distant = 0
+    titleset2 = set([])
+    for x in submission_answer:
+        title = x["title"]
+        h_idx = x["h_idx"]
+        t_idx = x["t_idx"]
+        r = x["r"]
+        titleset2.add(title)
+        if title not in title2vectexSet:
+            continue
+        vertexSet = title2vectexSet[title]
+
+        if "evidence" in x:
+            evi = set(x["evidence"])
+        else:
+            evi = set([])
+        pred_evi += len(evi)
+
+        if (title, r, h_idx, t_idx) in std:
+            correct_re += 1
+            stdevi = std[(title, r, h_idx, t_idx)]
+            correct_evidence += len(stdevi & evi)
+            in_train_annotated = in_train_distant = False
+            for n1 in vertexSet[h_idx]:
+                for n2 in vertexSet[t_idx]:
+                    if (n1["name"], n2["name"], r) in fact_in_train_annotated:
+                        in_train_annotated = True
+                    if (n1["name"], n2["name"], r) in fact_in_train_distant:
+                        in_train_distant = True
+
+            if in_train_annotated:
+                correct_in_train_annotated += 1
+            if in_train_distant:
+                correct_in_train_distant += 1
+
+    re_p = 1.0 * correct_re / len(submission_answer)
+    re_r = 1.0 * correct_re / tot_relations
+    if re_p + re_r == 0:
+        re_f1 = 0
+    else:
+        re_f1 = 2.0 * re_p * re_r / (re_p + re_r)
+
+    evi_p = 1.0 * correct_evidence / pred_evi if pred_evi > 0 else 0
+    evi_r = 1.0 * correct_evidence / tot_evidences
+    if evi_p + evi_r == 0:
+        evi_f1 = 0
+    else:
+        evi_f1 = 2.0 * evi_p * evi_r / (evi_p + evi_r)
+
+    re_p_ignore_train_annotated = (
+        1.0 * (correct_re - correct_in_train_annotated) / (len(submission_answer) - correct_in_train_annotated)
+    )
+    re_p_ignore_train = (
+        1.0 * (correct_re - correct_in_train_distant) / (len(submission_answer) - correct_in_train_distant)
+    )
+
+    if re_p_ignore_train_annotated + re_r == 0:
+        re_f1_ignore_train_annotated = 0
+    else:
+        re_f1_ignore_train_annotated = 2.0 * re_p_ignore_train_annotated * re_r / (re_p_ignore_train_annotated + re_r)
+
+    if re_p_ignore_train + re_r == 0:
+        re_f1_ignore_train = 0
+    else:
+        re_f1_ignore_train = 2.0 * re_p_ignore_train * re_r / (re_p_ignore_train + re_r)
+
+    print("RE_F1:", re_f1)
+    print("Evi_F1:", evi_f1)
+    print("RE_ignore_annotated_F1:", re_f1_ignore_train_annotated)
+    print("RE_ignore_distant_F1:", re_f1_ignore_train)
+
+    output_file.write("RE_F1: %f\n" % re_f1)
+    output_file.write("Evi_F1: %f\n" % evi_f1)
+
+    output_file.write("RE_ignore_annotated_F1: %f\n" % re_f1_ignore_train_annotated)
+    output_file.write("RE_ignore_distant_F1: %f\n" % re_f1_ignore_train)
+
+    output_file.close()
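For context (not part of the diff itself): official.py appears to be the original CodaLab-style DocRED scorer. It reads the gold file from <input_dir>/ref/dev_test.json, caches train facts as .fact files in that same ref/ directory (built from ../data/train_annotated.json and ../data/train_distant.json), reads system output from <input_dir>/res/result.json, and writes scores.txt into <output_dir>. Each entry of result.json is consumed as a flat dict; a sketch with illustrative values:

# One record of res/result.json as official.py reads it (illustrative values).
prediction = {
    "title": "Some document title",  # must match a title in dev_test.json
    "h_idx": 0,                      # head entity index into that document's vertexSet
    "t_idx": 3,                      # tail entity index
    "r": "P17",                      # relation identifier
    "evidence": [1, 2],              # optional supporting sentence ids
}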
sample.py ADDED
@@ -0,0 +1,10 @@
+import datasets
+import evaluate
+
+from docred import docred
+
+train_data = datasets.load_dataset("docred", split="train_annotated[:10]")
+data = datasets.load_dataset("docred", split="validation[:10]")
+metric = docred()
+
+print(metric.compute(predictions=data.to_list(), references=data.to_list()))
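Since sample.py scores the validation slice against itself, it mainly serves as a smoke test for the module: every reference relation is also a prediction, so f1, precision, recall, and ign_f1 should all come out just under 1.0, the small deficit coming from the 1e-5 smoothing terms in _compute.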