danieldux committed on
Commit
020c3bd
1 Parent(s): e908d09

Refactor hierarchical precision and recall calculation***

Browse files
Files changed (1) hide show
  1. ham.py +46 -33
ham.py CHANGED
@@ -1,40 +1,53 @@
1
def ancestors(class_label, hierarchy):
    """Return every transitive parent of a class label found in the hierarchy.

    Args:
    - class_label: The label whose ancestors are requested.
    - hierarchy: A container mapping a label to an iterable of its direct
      parents. Labels that are missing, or that map to an empty value, are
      treated as having no parents.

    Returns:
    - A set with every label reachable by repeatedly following parent links.
      It is empty when the label has no recorded parents.
    """
    found = set()
    pending = [class_label]
    # Iterative walk with a visited set: each label is expanded at most once,
    # so shared parents are not re-walked and a cyclic hierarchy terminates
    # instead of recursing until the interpreter's recursion limit.
    while pending:
        node = pending.pop()
        parents = hierarchy[node] if node in hierarchy else ()
        for parent in parents or ():
            if parent not in found:
                found.add(parent)
                pending.append(parent)
    return found
11
-
12
-
13
def extend_with_ancestors(class_labels, hierarchy):
    """Extend a collection of class labels with all of their ancestors.

    Args:
    - class_labels: An iterable of class labels.
    - hierarchy: The parent mapping passed through to ``ancestors``.

    Returns:
    - A set containing the given labels plus every ancestor reported by
      ``ancestors`` for each of them.
    """
    # Materialise once: iterating ``class_labels`` a second time would yield
    # nothing for a one-shot iterator, silently dropping all ancestors.
    labels = set(class_labels)
    extended_set = set(labels)
    for label in labels:
        extended_set.update(ancestors(label, hierarchy))
    return extended_set
19
-
20
-
21
def hierarchical_precision_recall(true_labels, predicted_labels, hierarchy):
    """Calculate hierarchical precision and recall.

    Each sample's label set is first extended with its ancestors, then the
    overlap between true and predicted sets is summed over all samples.

    Args:
    - true_labels: A sequence with one iterable of true class labels per sample.
    - predicted_labels: A sequence with one iterable of predicted class labels
      per sample, aligned with ``true_labels``.
    - hierarchy: The parent mapping passed through to ``extend_with_ancestors``.

    Returns:
    - hP: Summed overlap divided by the summed size of the extended predicted
      sets (0 when nothing was predicted).
    - hR: Summed overlap divided by the summed size of the extended true sets
      (0 when there are no true labels).

    Raises:
    - ValueError: If the two inputs do not contain the same number of samples.
    """
    true_extended = [extend_with_ancestors(ci, hierarchy) for ci in true_labels]
    predicted_extended = [
        extend_with_ancestors(c_prime_i, hierarchy) for c_prime_i in predicted_labels
    ]

    # zip() would silently drop the unmatched tail while the denominators
    # below still count every sample, producing a misleading score.
    if len(true_extended) != len(predicted_extended):
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            f"({len(true_extended)} != {len(predicted_extended)})"
        )

    intersect_sum = sum(
        len(ci & c_prime_i) for ci, c_prime_i in zip(true_extended, predicted_extended)
    )
    predicted_sum = sum(len(c_prime_i) for c_prime_i in predicted_extended)
    true_sum = sum(len(ci) for ci in true_extended)

    hP = intersect_sum / predicted_sum if predicted_sum > 0 else 0
    hR = intersect_sum / true_sum if true_sum > 0 else 0

    return hP, hR
38
 
39
 
40
  def hierarchical_f_measure(hP, hR, beta=1.0):
 
1
def find_ancestors(tree, code):
    """
    Finds the ancestors of a given class (e.g., an ISCO-08 code) by following
    parent links in a hierarchical structure.

    Args:
    - tree: A dictionary mapping each class label to a dictionary whose
      "parent" entry holds the label of its parent (falsy or absent for a
      top-level class).
    - code: A string representing the label of the class.

    Returns:
    - A list of strings, each representing an ancestor of the input class,
      ordered from the nearest parent upwards. The list is empty when the
      class is unknown to the tree or has no parent.
    """
    ancestors = []
    # Labels already on the path; guards against a cyclic parent chain, which
    # would otherwise make the walk loop forever.
    seen = {code}
    current = code
    while current:
        # A label missing from the tree (e.g. an out-of-vocabulary prediction)
        # or an entry without a "parent" key simply ends the walk instead of
        # raising KeyError.
        parent = tree.get(current, {}).get("parent")
        if not parent or parent in seen:
            break
        ancestors.append(parent)
        seen.add(parent)
        current = parent
    return ancestors
20
+
21
+
22
def calculate_hierarchical_measures(true_labels, predicted_labels, tree):
    """
    Calculates hierarchical precision, recall, and F-measure in a hierarchical structure.

    Every label is extended with the ancestors reported by ``find_ancestors``;
    the measures are then computed from the overlap between the extended true
    and predicted sets, summed over all samples.

    Args:
    - true_labels: A list of strings representing true class labels.
    - predicted_labels: A list of strings representing predicted class labels,
      aligned position by position with ``true_labels``.
    - tree: A dictionary representing the hierarchical structure.

    Returns:
    - hP: A floating point number representing hierarchical precision.
    - hR: A floating point number representing hierarchical recall.
    - hF: A floating point number representing hierarchical F-measure
      (harmonic mean of hP and hR).

    Raises:
    - ValueError: If the two label lists differ in length.
    """
    # find_ancestors returns a list, so it must be converted to a set before
    # the union; ``list | set`` raises TypeError.
    extended_true = [set(find_ancestors(tree, code)) | {code} for code in true_labels]
    extended_pred = [
        set(find_ancestors(tree, code)) | {code} for code in predicted_labels
    ]

    # zip() would silently ignore the unmatched tail while the totals below
    # still count every sample, which skews the scores.
    if len(extended_true) != len(extended_pred):
        raise ValueError(
            "true_labels and predicted_labels must have the same length "
            f"({len(extended_true)} != {len(extended_pred)})"
        )

    true_positive = sum(len(t & p) for t, p in zip(extended_true, extended_pred))
    predicted = sum(len(p) for p in extended_pred)
    actual = sum(len(t) for t in extended_true)

    # Guard each division so empty input yields zero scores rather than
    # ZeroDivisionError.
    hP = true_positive / predicted if predicted else 0.0
    hR = true_positive / actual if actual else 0.0
    hF = (2 * hP * hR) / (hP + hR) if (hP + hR) else 0.0

    return hP, hR, hF
51
 
52
 
53
  def hierarchical_f_measure(hP, hR, beta=1.0):