Refactor ISCO_Hierarchical_Accuracy class to use weighted hierarchy dictionary
Browse files
- isco_hierarchical_accuracy.py +42 -25
isco_hierarchical_accuracy.py
CHANGED
@@ -114,15 +114,14 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
114 |
|
115 |
def create_hierarchy_dict(self, file: str) -> dict:
|
116 |
"""
|
117 |
-
Creates a dictionary where keys are nodes and values are
|
118 |
-
|
119 |
-
A csv file with the ISCO-08 structure can be downloaded from the International Labour Organization (ILO) at [https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08 EN.csv](https://www.ilo.org/ilostat-files/ISCO/newdocs-08-2021/ISCO-08/ISCO-08%20EN.csv)
|
120 |
|
121 |
Args:
|
122 |
- file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
|
123 |
|
124 |
Returns:
|
125 |
-
- A dictionary where keys are ISCO-08 unit codes and values are
|
126 |
"""
|
127 |
|
128 |
try:
|
@@ -146,7 +145,12 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
146 |
minor_code = unit_code[0:3]
|
147 |
sub_major_code = unit_code[0:2]
|
148 |
major_code = unit_code[0]
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
return isco_hierarchy
|
152 |
|
@@ -192,40 +196,53 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
|
|
192 |
self,
|
193 |
reference_codes: List[str],
|
194 |
predicted_codes: List[str],
|
195 |
-
hierarchy: Dict[str,
|
196 |
) -> Tuple[float, float]:
|
197 |
"""
|
198 |
Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
|
199 |
|
200 |
Args:
|
201 |
-
|
202 |
predicted_codes (List[str]): The list of predicted codes.
|
203 |
hierarchy (Dict[str, Set[str]]): The hierarchy definition where keys are nodes and values are sets of parent nodes.
|
204 |
|
205 |
Returns:
|
206 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
207 |
"""
|
208 |
-
|
209 |
-
extended_real = set()
|
210 |
-
for code in reference_codes:
|
211 |
-
extended_real.add(code)
|
212 |
-
extended_real.update(hierarchy.get(code, set()))
|
213 |
|
214 |
-
|
215 |
-
for code in
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
218 |
|
219 |
-
|
220 |
-
correct_predictions = extended_real.intersection(extended_predicted)
|
221 |
|
222 |
-
#
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
|
230 |
return hP, hR
|
231 |
|
|
|
114 |
|
115 |
def create_hierarchy_dict(self, file: str) -> dict:
|
116 |
"""
|
117 |
+
Creates a dictionary where keys are nodes and values are dictionaries mapping their parent nodes to weights (higher for closer ancestors),
|
118 |
+
representing the group level hierarchy of the ISCO-08 structure.
|
|
|
119 |
|
120 |
Args:
|
121 |
- file: A string representing the path to the CSV file containing the 4-digit ISCO-08 codes. It can be a local path or a web URL.
|
122 |
|
123 |
Returns:
|
124 |
+
- A dictionary where keys are ISCO-08 unit codes and values are dictionaries mapping their parent codes to weights (0.75 for the minor group, 0.5 for the sub-major group, 0.25 for the major group).
|
125 |
"""
|
126 |
|
127 |
try:
|
|
|
145 |
minor_code = unit_code[0:3]
|
146 |
sub_major_code = unit_code[0:2]
|
147 |
major_code = unit_code[0]
|
148 |
+
|
149 |
+
# Assign weights, higher for closer ancestors
|
150 |
+
weights = {minor_code: 0.75, sub_major_code: 0.5, major_code: 0.25}
|
151 |
+
|
152 |
+
# Store ancestors with their weights
|
153 |
+
isco_hierarchy[unit_code] = weights
|
154 |
|
155 |
return isco_hierarchy
|
156 |
|
|
|
196 |
self,
|
197 |
reference_codes: List[str],
|
198 |
predicted_codes: List[str],
|
199 |
+
hierarchy: Dict[str, Dict[str, float]],
|
200 |
) -> Tuple[float, float]:
|
201 |
"""
|
202 |
Calculates the hierarchical precision and recall given the reference codes, predicted codes, and hierarchy definition.
|
203 |
|
204 |
Args:
|
205 |
+
reference_codes (List[str]): The list of reference codes.
|
206 |
predicted_codes (List[str]): The list of predicted codes.
|
207 |
hierarchy (Dict[str, Dict[str, float]]): The hierarchy definition where keys are nodes and values are dictionaries mapping each parent node to its weight.
|
208 |
|
209 |
Returns:
|
210 |
Tuple[float, float]: A tuple containing the hierarchical precision and recall floating point values.
|
211 |
"""
|
212 |
+
extended_real = {}
|
|
|
|
|
|
|
|
|
213 |
|
214 |
+
# Extend the sets of reference codes with their ancestors
|
215 |
+
for code in reference_codes:
|
216 |
+
weight = 1.0 # Full weight for exact match
|
217 |
+
extended_real[code] = weight
|
218 |
+
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
219 |
+
extended_real[ancestor] = max(
|
220 |
+
extended_real.get(ancestor, 0), ancestor_weight
|
221 |
+
)
|
222 |
|
223 |
+
extended_predicted = {}
|
|
|
224 |
|
225 |
+
# Extend the sets of predicted codes with their ancestors
|
226 |
+
for code in predicted_codes:
|
227 |
+
weight = 1.0
|
228 |
+
extended_predicted[code] = weight
|
229 |
+
for ancestor, ancestor_weight in hierarchy.get(code, {}).items():
|
230 |
+
extended_predicted[ancestor] = max(
|
231 |
+
extended_predicted.get(ancestor, 0), ancestor_weight
|
232 |
+
)
|
233 |
+
|
234 |
+
# Calculate weighted correct predictions
|
235 |
+
correct_weights = 0
|
236 |
+
for code, weight in extended_predicted.items():
|
237 |
+
if code in extended_real:
|
238 |
+
correct_weights += min(weight, extended_real[code])
|
239 |
+
|
240 |
+
total_predicted_weights = sum(extended_predicted.values())
|
241 |
+
total_real_weights = sum(extended_real.values())
|
242 |
+
|
243 |
+
# Calculate hierarchical precision and recall using weighted sums
|
244 |
+
hP = correct_weights / total_predicted_weights if total_predicted_weights else 0
|
245 |
+
hR = correct_weights / total_real_weights if total_real_weights else 0
|
246 |
|
247 |
return hP, hR
|
248 |
|