def _parents(label, hierarchy):
    """Return the direct parents of *label*, or an empty tuple if it has none."""
    # Use a membership test before indexing so that defaultdict-style
    # hierarchies are not mutated by the lookup.
    if label in hierarchy:
        return hierarchy[label] or ()
    return ()


def ancestors(class_label, hierarchy):
    """Return every ancestor of *class_label* in *hierarchy*.

    Args:
        class_label: The (hashable) label whose ancestors are wanted.
        hierarchy: Mapping from a label to an iterable of its direct parent
            labels. Labels missing from the mapping, or mapped to an empty or
            falsy value, are treated as having no parents.

    Returns:
        A set containing every label reachable by following parent links from
        *class_label*. The label itself is not included unless the hierarchy
        contains a cycle through it. Any node listed as a parent is included,
        so a root node appears in the result if the hierarchy lists it as a
        parent.
        NOTE(review): an earlier docstring said "excluding the root"; that
        only holds if top-level classes are given no parents. Confirm the
        intended hierarchy convention.
    """
    found = set()
    # Iterative depth-first walk with a visited set. It terminates on cyclic
    # hierarchies, never hits the recursion limit on deep ones, and visits
    # shared ancestors in a DAG only once.
    stack = list(_parents(class_label, hierarchy))
    while stack:
        node = stack.pop()
        if node in found:
            continue
        found.add(node)
        stack.extend(_parents(node, hierarchy))
    return found


def extend_with_ancestors(class_labels, hierarchy):
    """Return *class_labels* together with all of their ancestors.

    Args:
        class_labels: Iterable of labels (e.g. the label set of one sample).
        hierarchy: Mapping from a label to an iterable of its direct parents.

    Returns:
        A new set: the original labels plus every ancestor of each of them.
    """
    labels = set(class_labels)
    return labels.union(*(ancestors(label, hierarchy) for label in labels))


def hierarchical_precision_recall(true_labels, predicted_labels, hierarchy):
    """Compute micro-averaged hierarchical precision and recall.

    Each sample's true and predicted label sets are first extended with all
    of their ancestors. Then:

        hP = sum_i |T_i & P_i| / sum_i |P_i|
        hR = sum_i |T_i & P_i| / sum_i |T_i|

    Args:
        true_labels: Sequence of per-sample iterables of true labels.
        predicted_labels: Sequence of per-sample iterables of predicted
            labels, aligned index-by-index with *true_labels*.
        hierarchy: Mapping from a label to an iterable of its direct parents.

    Returns:
        A ``(hP, hR)`` tuple. Either value is ``0.0`` when its denominator
        is zero.

    Raises:
        ValueError: If *true_labels* and *predicted_labels* contain a
            different number of samples.
    """
    true_extended = [extend_with_ancestors(ci, hierarchy) for ci in true_labels]
    predicted_extended = [
        extend_with_ancestors(c_prime_i, hierarchy) for c_prime_i in predicted_labels
    ]
    # zip() would silently drop unmatched samples while the sums below still
    # count them, giving wrong metrics. Reject misaligned input explicitly.
    if len(true_extended) != len(predicted_extended):
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            f"(got {len(true_extended)} and {len(predicted_extended)})"
        )

    intersect_sum = sum(
        len(ci & c_prime_i) for ci, c_prime_i in zip(true_extended, predicted_extended)
    )
    predicted_sum = sum(len(c_prime_i) for c_prime_i in predicted_extended)
    true_sum = sum(len(ci) for ci in true_extended)

    hP = intersect_sum / predicted_sum if predicted_sum > 0 else 0.0
    hR = intersect_sum / true_sum if true_sum > 0 else 0.0
    return hP, hR


def hierarchical_f_measure(hP, hR, beta=1.0):
    """Combine hierarchical precision and recall into the F-beta measure.

        hF_beta = (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)

    Args:
        hP: Hierarchical precision.
        hR: Hierarchical recall.
        beta: Relative weight of recall versus precision (1.0 = balanced).

    Returns:
        The hierarchical F-beta score, or ``0.0`` when the denominator is
        zero.
    """
    denominator = beta**2 * hP + hR
    # Guard the actual denominator. Checking only hP + hR misses beta == 0
    # with hR == 0, which would otherwise raise ZeroDivisionError.
    if denominator == 0:
        return 0.0
    return (beta**2 + 1) * hP * hR / denominator