Move functions to main class
Browse files- isco_hierarchical_accuracy.py +140 -5
isco_hierarchical_accuracy.py
CHANGED
@@ -13,10 +13,12 @@
|
|
13 |
# limitations under the License.
|
14 |
"""ISCO-08 Hierarchical Accuracy Measure."""
|
15 |
|
|
|
16 |
import evaluate
|
17 |
import datasets
|
18 |
-
|
19 |
-
import
|
|
|
20 |
|
21 |
|
22 |
# TODO: Add BibTeX citation
|
@@ -110,11 +112,144 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
110 |
reference_urls=["http://path.to.reference.url/new_module"],
|
111 |
)
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
def _download_and_prepare(self, dl_manager):
|
114 |
"""Download external ISCO-08 csv file from the ILO website for creating the hierarchy dictionary."""
|
115 |
isco_csv = dl_manager.download_and_extract(ISCO_CSV_MIRROR_URL)
|
116 |
print(f"ISCO CSV file downloaded")
|
117 |
-
self.isco_hierarchy = isco.create_hierarchy_dict(isco_csv)
|
|
|
118 |
print("ISCO hierarchy dictionary created")
|
119 |
print(self.isco_hierarchy)
|
120 |
|
@@ -132,10 +267,10 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
132 |
|
133 |
# Calculate hierarchical precision, recall and f-measure
|
134 |
hierarchy = self.isco_hierarchy
|
135 |
-
hP, hR =
|
136 |
references, predictions, hierarchy
|
137 |
)
|
138 |
-
hF =
|
139 |
print(
|
140 |
f"Hierarchical Precision: {hP}, Hierarchical Recall: {hR}, Hierarchical F-measure: {hF}"
|
141 |
)
|
|
|
13 |
# limitations under the License.
|
14 |
"""ISCO-08 Hierarchical Accuracy Measure."""
|
15 |
|
16 |
+
from typing import List, Set, Dict, Tuple
|
17 |
import evaluate
|
18 |
import datasets
|
19 |
+
|
20 |
+
# import ham
|
21 |
+
# import isco
|
22 |
|
23 |
|
24 |
# TODO: Add BibTeX citation
|
|
|
112 |
reference_urls=["http://path.to.reference.url/new_module"],
|
113 |
)
|
114 |
|
115 |
+
def create_hierarchy_dict(file: str) -> dict:
|
116 |
+
"""
|
117 |
+
Creates a dictionary where keys are nodes and values are sets of parent nodes representing the group level hierarchy of the ISCO-08 structure.
|
118 |
+
The function assumes that the input CSV file has a column named 'unit' with the 4-digit ISCO-08 codes.
|
119 |
+
A csv file with the ISCO-08 structure can be downloaded from the International Labour Organization (ILO) at [https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08 EN.csv](https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08%20EN.csv)
|
120 |
+
|
121 |
+
Args:
|
122 |
+
- file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
- A dictionary where keys are ISCO-08 unit codes and values are sets of their parent codes.
|
126 |
+
"""
|
127 |
+
|
128 |
+
try:
|
129 |
+
import requests
|
130 |
+
import csv
|
131 |
+
except ImportError as error:
|
132 |
+
raise error
|
133 |
+
|
134 |
+
isco_hierarchy = {}
|
135 |
+
|
136 |
+
if file.startswith("http://") or file.startswith("https://"):
|
137 |
+
response = requests.get(file)
|
138 |
+
lines = response.text.splitlines()
|
139 |
+
else:
|
140 |
+
with open(file, newline="") as csvfile:
|
141 |
+
lines = csvfile.readlines()
|
142 |
+
|
143 |
+
reader = csv.DictReader(lines)
|
144 |
+
for row in reader:
|
145 |
+
unit_code = row["unit"].zfill(4)
|
146 |
+
minor_code = unit_code[0:3]
|
147 |
+
sub_major_code = unit_code[0:2]
|
148 |
+
major_code = unit_code[0]
|
149 |
+
isco_hierarchy[unit_code] = {minor_code, major_code, sub_major_code}
|
150 |
+
|
151 |
+
return isco_hierarchy
|
152 |
+
|
153 |
+
def find_ancestors(node: str, hierarchy: dict) -> set:
|
154 |
+
"""
|
155 |
+
Find the ancestors of a given node in a hierarchy.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
node (str): The node for which to find ancestors.
|
159 |
+
hierarchy (dict): A dictionary representing the hierarchy, where the keys are nodes and the values are their parents.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
set: A set of ancestors of the given node.
|
163 |
+
"""
|
164 |
+
ancestors = set()
|
165 |
+
nodes_to_visit = [node]
|
166 |
+
while nodes_to_visit:
|
167 |
+
current_node = nodes_to_visit.pop()
|
168 |
+
if current_node in hierarchy:
|
169 |
+
parents = hierarchy[current_node]
|
170 |
+
ancestors.update(parents)
|
171 |
+
nodes_to_visit.extend(parents)
|
172 |
+
return ancestors
|
173 |
+
|
174 |
+
def extend_with_ancestors(self, classes: set, hierarchy: dict) -> set:
|
175 |
+
"""
|
176 |
+
Extend the given set of classes with their ancestors from the hierarchy.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
classes (set): The set of classes to extend.
|
180 |
+
hierarchy (dict): The hierarchy of classes.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
set: The extended set of classes including their ancestors.
|
184 |
+
"""
|
185 |
+
extended_classes = set(classes)
|
186 |
+
for cls in classes:
|
187 |
+
ancestors = self.find_ancestors(cls, hierarchy)
|
188 |
+
extended_classes.update(ancestors)
|
189 |
+
return extended_classes
|
190 |
+
|
191 |
+
def calculate_hierarchical_precision_recall(
|
192 |
+
reference_codes: List[str],
|
193 |
+
predicted_codes: List[str],
|
194 |
+
hierarchy: Dict[str, Set[str]],
|
195 |
+
) -> Tuple[float, float]:
|
196 |
+
"""
|
197 |
+
Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
real_codes (List[str]): The list of reference codes.
|
201 |
+
predicted_codes (List[str]): The list of predicted codes.
|
202 |
+
hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
206 |
+
"""
|
207 |
+
# Extend the sets of real and predicted codes with their ancestors
|
208 |
+
extended_real = set()
|
209 |
+
for code in reference_codes:
|
210 |
+
extended_real.add(code)
|
211 |
+
extended_real.update(hierarchy.get(code, set()))
|
212 |
+
|
213 |
+
extended_predicted = set()
|
214 |
+
for code in predicted_codes:
|
215 |
+
extended_predicted.add(code)
|
216 |
+
extended_predicted.update(hierarchy.get(code, set()))
|
217 |
+
|
218 |
+
# Calculate the intersection
|
219 |
+
correct_predictions = extended_real.intersection(extended_predicted)
|
220 |
+
|
221 |
+
# Calculate hierarchical precision and recall
|
222 |
+
hP = (
|
223 |
+
len(correct_predictions) / len(extended_predicted)
|
224 |
+
if extended_predicted
|
225 |
+
else 0
|
226 |
+
)
|
227 |
+
hR = len(correct_predictions) / len(extended_real) if extended_real else 0
|
228 |
+
|
229 |
+
return hP, hR
|
230 |
+
|
231 |
+
def hierarchical_f_measure(hP, hR, beta=1.0):
|
232 |
+
"""
|
233 |
+
Calculate the hierarchical F-measure.
|
234 |
+
|
235 |
+
Parameters:
|
236 |
+
hP (float): The hierarchical precision.
|
237 |
+
hR (float): The hierarchical recall.
|
238 |
+
beta (float, optional): The beta value for F-measure calculation. Default is 1.0.
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
float: The hierarchical F-measure.
|
242 |
+
"""
|
243 |
+
if hP + hR == 0:
|
244 |
+
return 0
|
245 |
+
return (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)
|
246 |
+
|
247 |
def _download_and_prepare(self, dl_manager):
|
248 |
"""Download external ISCO-08 csv file from the ILO website for creating the hierarchy dictionary."""
|
249 |
isco_csv = dl_manager.download_and_extract(ISCO_CSV_MIRROR_URL)
|
250 |
print(f"ISCO CSV file downloaded")
|
251 |
+
# self.isco_hierarchy = isco.create_hierarchy_dict(isco_csv)
|
252 |
+
self.isco_hierarchy = self.create_hierarchy_dict(isco_csv)
|
253 |
print("ISCO hierarchy dictionary created")
|
254 |
print(self.isco_hierarchy)
|
255 |
|
|
|
267 |
|
268 |
# Calculate hierarchical precision, recall and f-measure
|
269 |
hierarchy = self.isco_hierarchy
|
270 |
+
hP, hR = self.calculate_hierarchical_precision_recall(
|
271 |
references, predictions, hierarchy
|
272 |
)
|
273 |
+
hF = self.hierarchical_f_measure(hP, hR)
|
274 |
print(
|
275 |
f"Hierarchical Precision: {hP}, Hierarchical Recall: {hR}, Hierarchical F-measure: {hF}"
|
276 |
)
|