|
def ancestors(class_label, hierarchy):
    """Return all transitive ancestors (parents, grandparents, ...) of a label.

    Args:
        class_label: The label whose ancestors are wanted.
        hierarchy: Mapping from a class label to an iterable of its direct
            parent labels. Labels missing from the mapping, or mapped to an
            empty/falsy value, are treated as having no parents.

    Returns:
        A new set of ancestor labels. ``class_label`` itself is not included
        unless it lies on a cycle in ``hierarchy``.

    NOTE(review): an earlier docstring said "excluding the root". Nothing
    here filters a root node out. That only holds if the hierarchy does not
    list the root as anybody's parent; confirm against the data.
    """
    found = set()
    # Iterative DFS with a visited set: it terminates on cyclic hierarchies,
    # is not bounded by the recursion limit, and visits each shared ancestor
    # of a DAG ("diamond") only once.
    stack = [class_label]
    while stack:
        node = stack.pop()
        parents = hierarchy[node] if node in hierarchy else None
        if not parents:
            continue
        for parent in parents:
            if parent not in found:
                found.add(parent)
                stack.append(parent)
    return found
|
|
|
|
|
def extend_with_ancestors(class_labels, hierarchy):
    """Extend a set of class labels with all of their ancestors.

    Args:
        class_labels: Iterable of class labels. It is consumed exactly once,
            so generators and other one-shot iterators are handled correctly.
        hierarchy: Mapping from a class label to an iterable of its parent
            labels, as consumed by ``ancestors``.

    Returns:
        A new set containing every label in ``class_labels`` plus all of
        their ancestors. ``class_labels`` itself is not modified.
    """
    # Materialize once: iterating ``class_labels`` twice would exhaust a
    # one-shot iterator on the first pass and silently skip the ancestors.
    # Iterating the deduplicated set also avoids repeat lookups for duplicates.
    labels = set(class_labels)
    extended_set = set(labels)
    for label in labels:
        extended_set.update(ancestors(label, hierarchy))
    return extended_set
|
|
|
|
|
def hierarchical_precision_recall(true_labels, predicted_labels, hierarchy):
    """Calculate micro-averaged hierarchical precision and recall.

    Each sample's true label set C_i and predicted label set C'_i are first
    extended with all of their ancestors. Counts are then summed over all
    samples before dividing:

        hP = sum_i |C_i & C'_i| / sum_i |C'_i|
        hR = sum_i |C_i & C'_i| / sum_i |C_i|

    Args:
        true_labels: Iterable of per-sample collections of true class labels.
        predicted_labels: Iterable of per-sample collections of predicted
            class labels, aligned sample-by-sample with ``true_labels``.
        hierarchy: Mapping from a class label to an iterable of its parents.

    Returns:
        Tuple ``(hP, hR)``. Each value is 0.0 when its denominator is 0.

    Raises:
        ValueError: If ``true_labels`` and ``predicted_labels`` contain a
            different number of samples.
    """
    true_extended = [extend_with_ancestors(ci, hierarchy) for ci in true_labels]
    predicted_extended = [
        extend_with_ancestors(c_prime_i, hierarchy) for c_prime_i in predicted_labels
    ]

    # zip() below would silently drop unmatched samples from the intersection
    # while the denominators still count them, skewing both scores.
    if len(true_extended) != len(predicted_extended):
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            f"(got {len(true_extended)} and {len(predicted_extended)})"
        )

    intersect_sum = sum(
        len(ci & c_prime_i) for ci, c_prime_i in zip(true_extended, predicted_extended)
    )
    predicted_sum = sum(len(c_prime_i) for c_prime_i in predicted_extended)
    true_sum = sum(len(ci) for ci in true_extended)

    hP = intersect_sum / predicted_sum if predicted_sum > 0 else 0.0
    hR = intersect_sum / true_sum if true_sum > 0 else 0.0
    return hP, hR
|
|
|
|
|
def hierarchical_f_measure(hP, hR, beta=1.0):
    """Calculate the hierarchical F-beta measure from hP and hR.

    Computes ``(beta**2 + 1) * hP * hR / (beta**2 * hP + hR)``.

    Args:
        hP: Hierarchical precision.
        hR: Hierarchical recall.
        beta: Relative weight of recall versus precision (1.0 gives the
            harmonic mean).

    Returns:
        The F-beta value, or 0.0 when the denominator is zero.
    """
    denominator = beta**2 * hP + hR
    # Guard the actual denominator, not hP + hR. With beta == 0 the
    # denominator reduces to hR, which can be 0 while hP is not.
    if denominator == 0:
        return 0.0
    return (beta**2 + 1) * hP * hR / denominator
|
|