danieldux commited on
Commit
3f5c862
1 Parent(s): 78950cf

Move functions to main class

Browse files
Files changed (1) hide show
  1. isco_hierarchical_accuracy.py +140 -5
isco_hierarchical_accuracy.py CHANGED
@@ -13,10 +13,12 @@
13
  # limitations under the License.
14
  """ISCO-08 Hierarchical Accuracy Measure."""
15
 
 
16
  import evaluate
17
  import datasets
18
- import ham
19
- import isco
 
20
 
21
 
22
  # TODO: Add BibTeX citation
@@ -110,11 +112,144 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
110
  reference_urls=["http://path.to.reference.url/new_module"],
111
  )
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def _download_and_prepare(self, dl_manager):
114
  """Download external ISCO-08 csv file from the ILO website for creating the hierarchy dictionary."""
115
  isco_csv = dl_manager.download_and_extract(ISCO_CSV_MIRROR_URL)
116
  print(f"ISCO CSV file downloaded")
117
- self.isco_hierarchy = isco.create_hierarchy_dict(isco_csv)
 
118
  print("ISCO hierarchy dictionary created")
119
  print(self.isco_hierarchy)
120
 
@@ -132,10 +267,10 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
132
 
133
  # Calculate hierarchical precision, recall and f-measure
134
  hierarchy = self.isco_hierarchy
135
- hP, hR = ham.calculate_hierarchical_precision_recall(
136
  references, predictions, hierarchy
137
  )
138
- hF = ham.hierarchical_f_measure(hP, hR)
139
  print(
140
  f"Hierarchical Precision: {hP}, Hierarchical Recall: {hR}, Hierarchical F-measure: {hF}"
141
  )
 
13
  # limitations under the License.
14
  """ISCO-08 Hierarchical Accuracy Measure."""
15
 
16
+ from typing import List, Set, Dict, Tuple
17
  import evaluate
18
  import datasets
19
+
20
+ # import ham
21
+ # import isco
22
 
23
 
24
  # TODO: Add BibTeX citation
 
112
  reference_urls=["http://path.to.reference.url/new_module"],
113
  )
114
 
115
+ def create_hierarchy_dict(file: str) -> dict:
116
+ """
117
+ Creates a dictionary where keys are nodes and values are sets of parent nodes representing the group level hierarchy of the ISCO-08 structure.
118
+ The function assumes that the input CSV file has a column named 'unit' with the 4-digit ISCO-08 codes.
119
+ A csv file with the ISCO-08 structure can be downloaded from the International Labour Organization (ILO) at [https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08 EN.csv](https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08%20EN.csv)
120
+
121
+ Args:
122
+ - file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
123
+
124
+ Returns:
125
+ - A dictionary where keys are ISCO-08 unit codes and values are sets of their parent codes.
126
+ """
127
+
128
+ try:
129
+ import requests
130
+ import csv
131
+ except ImportError as error:
132
+ raise error
133
+
134
+ isco_hierarchy = {}
135
+
136
+ if file.startswith("http://") or file.startswith("https://"):
137
+ response = requests.get(file)
138
+ lines = response.text.splitlines()
139
+ else:
140
+ with open(file, newline="") as csvfile:
141
+ lines = csvfile.readlines()
142
+
143
+ reader = csv.DictReader(lines)
144
+ for row in reader:
145
+ unit_code = row["unit"].zfill(4)
146
+ minor_code = unit_code[0:3]
147
+ sub_major_code = unit_code[0:2]
148
+ major_code = unit_code[0]
149
+ isco_hierarchy[unit_code] = {minor_code, major_code, sub_major_code}
150
+
151
+ return isco_hierarchy
152
+
153
+ def find_ancestors(node: str, hierarchy: dict) -> set:
154
+ """
155
+ Find the ancestors of a given node in a hierarchy.
156
+
157
+ Args:
158
+ node (str): The node for which to find ancestors.
159
+ hierarchy (dict): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
160
+
161
+ Returns:
162
+ set: A set of ancestors of the given node.
163
+ """
164
+ ancestors = set()
165
+ nodes_to_visit = [node]
166
+ while nodes_to_visit:
167
+ current_node = nodes_to_visit.pop()
168
+ if current_node in hierarchy:
169
+ parents = hierarchy[current_node]
170
+ ancestors.update(parents)
171
+ nodes_to_visit.extend(parents)
172
+ return ancestors
173
+
174
+ def extend_with_ancestors(self, classes: set, hierarchy: dict) -> set:
175
+ """
176
+ Extend the given set of classes with their ancestors from the hierarchy.
177
+
178
+ Args:
179
+ classes (set): The set of classes to extend.
180
+ hierarchy (dict): The hierarchy of classes.
181
+
182
+ Returns:
183
+ set: The extended set of classes including their ancestors.
184
+ """
185
+ extended_classes = set(classes)
186
+ for cls in classes:
187
+ ancestors = self.find_ancestors(cls, hierarchy)
188
+ extended_classes.update(ancestors)
189
+ return extended_classes
190
+
191
+ def calculate_hierarchical_precision_recall(
192
+ reference_codes: List[str],
193
+ predicted_codes: List[str],
194
+ hierarchy: Dict[str, Set[str]],
195
+ ) -> Tuple[float, float]:
196
+ """
197
+ Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
198
+
199
+ Args:
200
+ real_codes (List[str]): The list of reference codes.
201
+ predicted_codes (List[str]): The list of predicted codes.
202
+ hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
203
+
204
+ Returns:
205
+ Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
206
+ """
207
+ # Extend the sets of real and predicted codes with their ancestors
208
+ extended_real = set()
209
+ for code in reference_codes:
210
+ extended_real.add(code)
211
+ extended_real.update(hierarchy.get(code, set()))
212
+
213
+ extended_predicted = set()
214
+ for code in predicted_codes:
215
+ extended_predicted.add(code)
216
+ extended_predicted.update(hierarchy.get(code, set()))
217
+
218
+ # Calculate the intersection
219
+ correct_predictions = extended_real.intersection(extended_predicted)
220
+
221
+ # Calculate hierarchical precision and recall
222
+ hP = (
223
+ len(correct_predictions) / len(extended_predicted)
224
+ if extended_predicted
225
+ else 0
226
+ )
227
+ hR = len(correct_predictions) / len(extended_real) if extended_real else 0
228
+
229
+ return hP, hR
230
+
231
+ def hierarchical_f_measure(hP, hR, beta=1.0):
232
+ """
233
+ Calculate the hierarchical F-measure.
234
+
235
+ Parameters:
236
+ hP (float): The hierarchical precision.
237
+ hR (float): The hierarchical recall.
238
+ beta (float, optional): The beta value for F-measure calculation. Default is 1.0.
239
+
240
+ Returns:
241
+ float: The hierarchical F-measure.
242
+ """
243
+ if hP + hR == 0:
244
+ return 0
245
+ return (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)
246
+
247
  def _download_and_prepare(self, dl_manager):
248
  """Download external ISCO-08 csv file from the ILO website for creating the hierarchy dictionary."""
249
  isco_csv = dl_manager.download_and_extract(ISCO_CSV_MIRROR_URL)
250
  print(f"ISCO CSV file downloaded")
251
+ # self.isco_hierarchy = isco.create_hierarchy_dict(isco_csv)
252
+ self.isco_hierarchy = self.create_hierarchy_dict(isco_csv)
253
  print("ISCO hierarchy dictionary created")
254
  print(self.isco_hierarchy)
255
 
 
267
 
268
  # Calculate hierarchical precision, recall and f-measure
269
  hierarchy = self.isco_hierarchy
270
+ hP, hR = self.calculate_hierarchical_precision_recall(
271
  references, predictions, hierarchy
272
  )
273
+ hF = self.hierarchical_f_measure(hP, hR)
274
  print(
275
  f"Hierarchical Precision: {hP}, Hierarchical Recall: {hR}, Hierarchical F-measure: {hF}"
276
  )